quantrs2_core/
quantum_autodiff.rs

1//! Enhanced Automatic Differentiation for Quantum Gradients
2//!
3//! This module provides advanced automatic differentiation capabilities
4//! specifically designed for quantum computing, including parameter-shift
5//! rules, finite differences, and hybrid classical-quantum gradients.
6
7use crate::error::{QuantRS2Error, QuantRS2Result};
8use scirs2_core::Complex64;
9use std::{
10    collections::HashMap,
11    fmt,
12    sync::{Arc, RwLock},
13};
14
/// Settings that control how the quantum autodiff engine computes gradients.
#[derive(Clone, Debug)]
pub struct QuantumAutoDiffConfig {
    /// Method used when a caller does not request one explicitly.
    pub default_method: DifferentiationMethod,
    /// Step `h` used by the finite/central difference approximations.
    pub finite_diff_step: f64,
    /// Shift applied to a parameter by the parameter-shift rule.
    pub parameter_shift_step: f64,
    /// Whether higher-order derivatives may be requested at all.
    pub enable_higher_order: bool,
    /// Highest derivative order that will be computed.
    pub max_derivative_order: usize,
    /// Target precision for gradient computation.
    pub gradient_precision: f64,
    /// Whether computed gradients are cached.
    pub enable_caching: bool,
    /// Size limit for the gradient cache (number of entries).
    pub cache_size_limit: usize,
}

impl Default for QuantumAutoDiffConfig {
    /// Parameter-shift by default with a π/2 shift, a `1e-7` difference step,
    /// derivatives up to third order, and caching of up to 10 000 entries.
    fn default() -> Self {
        Self {
            default_method: DifferentiationMethod::ParameterShift,
            finite_diff_step: 1e-7,
            // FRAC_PI_2 is bit-for-bit identical to PI / 2.0.
            parameter_shift_step: std::f64::consts::FRAC_PI_2,
            enable_higher_order: true,
            max_derivative_order: 3,
            gradient_precision: 1e-12,
            enable_caching: true,
            cache_size_limit: 10_000,
        }
    }
}

/// Strategies available for computing quantum gradients.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DifferentiationMethod {
    /// Parameter-shift rule (exact for many quantum gates).
    ParameterShift,
    /// One-sided finite differences (numerical approximation).
    FiniteDifference,
    /// Two-sided (central) differences, more accurate than one-sided.
    CentralDifference,
    /// Complex-step differentiation.
    ComplexStep,
    /// Automatic differentiation based on dual numbers.
    DualNumber,
    /// Automatic per-parameter choice of method.
    Hybrid,
}

impl fmt::Display for DifferentiationMethod {
    /// Writes a human-readable label; width/alignment flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::ParameterShift => "Parameter-Shift",
            Self::FiniteDifference => "Finite Difference",
            Self::CentralDifference => "Central Difference",
            Self::ComplexStep => "Complex Step",
            Self::DualNumber => "Dual Number",
            Self::Hybrid => "Hybrid",
        };
        f.write_str(label)
    }
}
80
81/// Quantum automatic differentiation engine
82#[derive(Debug)]
83pub struct QuantumAutoDiff {
84    config: QuantumAutoDiffConfig,
85    gradient_cache: Arc<RwLock<GradientCache>>,
86    computation_graph: Arc<RwLock<ComputationGraph>>,
87    parameter_registry: Arc<RwLock<ParameterRegistry>>,
88}
89
90/// Cache for computed gradients
91#[derive(Debug)]
92pub struct GradientCache {
93    entries: HashMap<String, CacheEntry>,
94    access_order: Vec<String>,
95    total_size: usize,
96}
97
98#[derive(Debug, Clone)]
99pub struct CacheEntry {
100    gradients: Vec<Complex64>,
101    computation_cost: f64,
102    timestamp: std::time::Instant,
103    method_used: DifferentiationMethod,
104}
105
106/// Computation graph for tracking quantum operations
107#[derive(Debug)]
108pub struct ComputationGraph {
109    nodes: Vec<ComputationNode>,
110    edges: Vec<ComputationEdge>,
111    parameter_dependencies: HashMap<usize, Vec<usize>>,
112}
113
114#[derive(Debug, Clone)]
115pub struct ComputationNode {
116    id: usize,
117    operation: QuantumOperation,
118    inputs: Vec<usize>,
119    outputs: Vec<usize>,
120}
121
122#[derive(Debug, Clone)]
123pub struct ComputationEdge {
124    from: usize,
125    to: usize,
126    parameter_id: Option<usize>,
127}
128
129#[derive(Debug, Clone)]
130pub enum QuantumOperation {
131    Gate { name: String, parameters: Vec<f64> },
132    Measurement { observable: String },
133    StatePreparation { amplitudes: Vec<Complex64> },
134    Expectation { observable: String },
135}
136
137/// Registry for tracking parameters
138#[derive(Debug)]
139pub struct ParameterRegistry {
140    parameters: HashMap<usize, Parameter>,
141    next_id: usize,
142}
143
144#[derive(Debug, Clone)]
145pub struct Parameter {
146    id: usize,
147    name: String,
148    value: f64,
149    bounds: Option<(f64, f64)>,
150    differentiable: bool,
151    gradient_method: Option<DifferentiationMethod>,
152}
153
154/// Result of gradient computation
155#[derive(Debug, Clone)]
156pub struct GradientResult {
157    pub gradients: Vec<Complex64>,
158    pub parameter_ids: Vec<usize>,
159    pub computation_method: DifferentiationMethod,
160    pub computation_time: std::time::Duration,
161    pub numerical_error_estimate: f64,
162}
163
164/// Higher-order derivative result
165#[derive(Debug, Clone)]
166pub struct HigherOrderResult {
167    pub derivatives: Vec<Vec<Complex64>>, // derivatives[order][parameter]
168    pub parameter_ids: Vec<usize>,
169    pub orders: Vec<usize>,
170    pub mixed_derivatives: HashMap<(usize, usize), Complex64>,
171}
172
173impl QuantumAutoDiff {
174    /// Create a new quantum automatic differentiation engine
175    pub fn new(config: QuantumAutoDiffConfig) -> Self {
176        Self {
177            config,
178            gradient_cache: Arc::new(RwLock::new(GradientCache::new())),
179            computation_graph: Arc::new(RwLock::new(ComputationGraph::new())),
180            parameter_registry: Arc::new(RwLock::new(ParameterRegistry::new())),
181        }
182    }
183
184    /// Register a parameter for differentiation
185    pub fn register_parameter(
186        &mut self,
187        name: &str,
188        initial_value: f64,
189        bounds: Option<(f64, f64)>,
190    ) -> QuantRS2Result<usize> {
191        let mut registry = self
192            .parameter_registry
193            .write()
194            .expect("Parameter registry lock poisoned during registration");
195        Ok(registry.add_parameter(name, initial_value, bounds))
196    }
197
198    /// Compute gradients using the specified method
199    pub fn compute_gradients<F>(
200        &mut self,
201        function: F,
202        parameter_ids: &[usize],
203        method: Option<DifferentiationMethod>,
204    ) -> QuantRS2Result<GradientResult>
205    where
206        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
207    {
208        let method = method.unwrap_or(self.config.default_method);
209        let start_time = std::time::Instant::now();
210
211        // Check cache first
212        if self.config.enable_caching {
213            let cache_key = self.generate_cache_key(parameter_ids, method);
214            if let Some(cached) = self.get_cached_gradient(&cache_key) {
215                return Ok(GradientResult {
216                    gradients: cached.gradients,
217                    parameter_ids: parameter_ids.to_vec(),
218                    computation_method: cached.method_used,
219                    computation_time: start_time.elapsed(),
220                    numerical_error_estimate: 0.0, // Cached result
221                });
222            }
223        }
224
225        // Get current parameter values
226        let parameter_values = self.get_parameter_values(parameter_ids)?;
227
228        let gradients = match method {
229            DifferentiationMethod::ParameterShift => {
230                self.compute_parameter_shift_gradients(function, &parameter_values, parameter_ids)?
231            }
232            DifferentiationMethod::FiniteDifference => self.compute_finite_difference_gradients(
233                function,
234                &parameter_values,
235                parameter_ids,
236            )?,
237            DifferentiationMethod::CentralDifference => self.compute_central_difference_gradients(
238                function,
239                &parameter_values,
240                parameter_ids,
241            )?,
242            DifferentiationMethod::ComplexStep => {
243                self.compute_complex_step_gradients(function, &parameter_values, parameter_ids)?
244            }
245            DifferentiationMethod::DualNumber => {
246                self.compute_dual_number_gradients(function, &parameter_values, parameter_ids)?
247            }
248            DifferentiationMethod::Hybrid => {
249                self.compute_hybrid_gradients(function, &parameter_values, parameter_ids)?
250            }
251        };
252
253        let computation_time = start_time.elapsed();
254        let error_estimate = self.estimate_gradient_error(&gradients, method);
255
256        let result = GradientResult {
257            gradients: gradients.clone(),
258            parameter_ids: parameter_ids.to_vec(),
259            computation_method: method,
260            computation_time,
261            numerical_error_estimate: error_estimate,
262        };
263
264        // Cache the result
265        if self.config.enable_caching {
266            let cache_key = self.generate_cache_key(parameter_ids, method);
267            self.cache_gradient(
268                cache_key,
269                &gradients,
270                computation_time.as_secs_f64(),
271                method,
272            );
273        }
274
275        Ok(result)
276    }
277
278    /// Compute higher-order derivatives
279    pub fn compute_higher_order_derivatives<F>(
280        &mut self,
281        function: F,
282        parameter_ids: &[usize],
283        max_order: usize,
284    ) -> QuantRS2Result<HigherOrderResult>
285    where
286        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
287    {
288        if !self.config.enable_higher_order {
289            return Err(QuantRS2Error::UnsupportedOperation(
290                "Higher-order derivatives disabled".to_string(),
291            ));
292        }
293
294        let max_order = max_order.min(self.config.max_derivative_order);
295        let mut derivatives = Vec::new();
296        let mut mixed_derivatives = HashMap::new();
297
298        // Compute derivatives of each order
299        for order in 1..=max_order {
300            let order_derivatives =
301                self.compute_nth_order_derivatives(function, parameter_ids, order)?;
302            derivatives.push(order_derivatives);
303        }
304
305        // Compute mixed partial derivatives for second order
306        if max_order >= 2 && parameter_ids.len() >= 2 {
307            for i in 0..parameter_ids.len() {
308                for j in (i + 1)..parameter_ids.len() {
309                    let mixed = self.compute_mixed_partial(function, parameter_ids, i, j)?;
310                    mixed_derivatives.insert((i, j), mixed);
311                }
312            }
313        }
314
315        Ok(HigherOrderResult {
316            derivatives,
317            parameter_ids: parameter_ids.to_vec(),
318            orders: (1..=max_order).collect(),
319            mixed_derivatives,
320        })
321    }
322
323    /// Compute gradients with respect to quantum circuit parameters
324    pub fn circuit_gradients<F>(
325        &mut self,
326        circuit_function: F,
327        gate_parameters: &[(usize, String, Vec<usize>)], // (gate_id, gate_name, param_indices)
328        observable: &str,
329    ) -> QuantRS2Result<Vec<GradientResult>>
330    where
331        F: Fn(&[f64], &str) -> QuantRS2Result<Complex64> + Copy,
332    {
333        let mut results = Vec::new();
334
335        for (_gate_id, gate_name, param_indices) in gate_parameters {
336            // Determine best differentiation method for this gate
337            let method = self.select_optimal_method(gate_name);
338
339            let gate_function = |params: &[f64]| -> QuantRS2Result<Complex64> {
340                circuit_function(params, observable)
341            };
342
343            let gradient = self.compute_gradients(gate_function, param_indices, Some(method))?;
344            results.push(gradient);
345        }
346
347        Ok(results)
348    }
349
350    /// Optimize parameter update using gradient information
351    pub fn parameter_update(
352        &mut self,
353        gradients: &GradientResult,
354        learning_rate: f64,
355        optimizer: OptimizerType,
356    ) -> QuantRS2Result<()> {
357        match optimizer {
358            OptimizerType::SGD => {
359                self.sgd_update(gradients, learning_rate)?;
360            }
361            OptimizerType::Adam => {
362                self.adam_update(gradients, learning_rate)?;
363            }
364            OptimizerType::LBFGS => {
365                self.lbfgs_update(gradients, learning_rate)?;
366            }
367            OptimizerType::AdaGrad => {
368                self.adagrad_update(gradients, learning_rate)?;
369            }
370        }
371        Ok(())
372    }
373
374    // Private methods for different differentiation approaches
375
376    fn compute_parameter_shift_gradients<F>(
377        &self,
378        function: F,
379        parameters: &[f64],
380        _parameter_ids: &[usize],
381    ) -> QuantRS2Result<Vec<Complex64>>
382    where
383        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
384    {
385        let mut gradients = Vec::new();
386        let shift = self.config.parameter_shift_step;
387
388        for i in 0..parameters.len() {
389            let mut params_plus = parameters.to_vec();
390            let mut params_minus = parameters.to_vec();
391
392            params_plus[i] += shift;
393            params_minus[i] -= shift;
394
395            let f_plus = function(&params_plus)?;
396            let f_minus = function(&params_minus)?;
397
398            // Parameter-shift rule: gradient = (f(θ + π/2) - f(θ - π/2)) / 2
399            let gradient = (f_plus - f_minus) / Complex64::new(2.0, 0.0);
400            gradients.push(gradient);
401        }
402
403        Ok(gradients)
404    }
405
406    fn compute_finite_difference_gradients<F>(
407        &self,
408        function: F,
409        parameters: &[f64],
410        _parameter_ids: &[usize],
411    ) -> QuantRS2Result<Vec<Complex64>>
412    where
413        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
414    {
415        let mut gradients = Vec::new();
416        let h = self.config.finite_diff_step;
417        let f_original = function(parameters)?;
418
419        for i in 0..parameters.len() {
420            let mut params_h = parameters.to_vec();
421            params_h[i] += h;
422
423            let f_h = function(&params_h)?;
424            let gradient = (f_h - f_original) / Complex64::new(h, 0.0);
425            gradients.push(gradient);
426        }
427
428        Ok(gradients)
429    }
430
431    fn compute_central_difference_gradients<F>(
432        &self,
433        function: F,
434        parameters: &[f64],
435        _parameter_ids: &[usize],
436    ) -> QuantRS2Result<Vec<Complex64>>
437    where
438        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
439    {
440        let mut gradients = Vec::new();
441        let h = self.config.finite_diff_step;
442
443        for i in 0..parameters.len() {
444            let mut params_plus = parameters.to_vec();
445            let mut params_minus = parameters.to_vec();
446
447            params_plus[i] += h;
448            params_minus[i] -= h;
449
450            let f_plus = function(&params_plus)?;
451            let f_minus = function(&params_minus)?;
452
453            // Central difference: gradient = (f(θ + h) - f(θ - h)) / (2h)
454            let gradient = (f_plus - f_minus) / Complex64::new(2.0 * h, 0.0);
455            gradients.push(gradient);
456        }
457
458        Ok(gradients)
459    }
460
461    fn compute_complex_step_gradients<F>(
462        &self,
463        function: F,
464        parameters: &[f64],
465        _parameter_ids: &[usize],
466    ) -> QuantRS2Result<Vec<Complex64>>
467    where
468        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
469    {
470        // Complex step differentiation is not directly applicable to real-valued parameters
471        // This is a simplified implementation
472        self.compute_central_difference_gradients(function, parameters, _parameter_ids)
473    }
474
475    fn compute_dual_number_gradients<F>(
476        &self,
477        function: F,
478        parameters: &[f64],
479        _parameter_ids: &[usize],
480    ) -> QuantRS2Result<Vec<Complex64>>
481    where
482        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
483    {
484        // Simplified dual number implementation using finite differences
485        self.compute_central_difference_gradients(function, parameters, _parameter_ids)
486    }
487
488    fn compute_hybrid_gradients<F>(
489        &self,
490        function: F,
491        parameters: &[f64],
492        parameter_ids: &[usize],
493    ) -> QuantRS2Result<Vec<Complex64>>
494    where
495        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
496    {
497        // Use parameter-shift for most parameters, finite difference for others
498        let registry = self
499            .parameter_registry
500            .read()
501            .expect("Parameter registry lock poisoned during hybrid gradient computation");
502        let mut gradients = Vec::new();
503
504        for (i, &param_id) in parameter_ids.iter().enumerate() {
505            let param = registry.parameters.get(&param_id);
506            let method = param
507                .and_then(|p| p.gradient_method)
508                .unwrap_or(DifferentiationMethod::ParameterShift);
509
510            let single_param_gradient = match method {
511                DifferentiationMethod::ParameterShift => {
512                    self.compute_single_parameter_shift_gradient(function, parameters, i)?
513                }
514                _ => self.compute_single_finite_difference_gradient(function, parameters, i)?,
515            };
516
517            gradients.push(single_param_gradient);
518        }
519
520        Ok(gradients)
521    }
522
523    fn compute_single_parameter_shift_gradient<F>(
524        &self,
525        function: F,
526        parameters: &[f64],
527        param_index: usize,
528    ) -> QuantRS2Result<Complex64>
529    where
530        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
531    {
532        let shift = self.config.parameter_shift_step;
533        let mut params_plus = parameters.to_vec();
534        let mut params_minus = parameters.to_vec();
535
536        params_plus[param_index] += shift;
537        params_minus[param_index] -= shift;
538
539        let f_plus = function(&params_plus)?;
540        let f_minus = function(&params_minus)?;
541
542        Ok((f_plus - f_minus) / Complex64::new(2.0, 0.0))
543    }
544
545    fn compute_single_finite_difference_gradient<F>(
546        &self,
547        function: F,
548        parameters: &[f64],
549        param_index: usize,
550    ) -> QuantRS2Result<Complex64>
551    where
552        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
553    {
554        let h = self.config.finite_diff_step;
555        let mut params_plus = parameters.to_vec();
556        let mut params_minus = parameters.to_vec();
557
558        params_plus[param_index] += h;
559        params_minus[param_index] -= h;
560
561        let f_plus = function(&params_plus)?;
562        let f_minus = function(&params_minus)?;
563
564        Ok((f_plus - f_minus) / Complex64::new(2.0 * h, 0.0))
565    }
566
567    fn compute_nth_order_derivatives<F>(
568        &self,
569        function: F,
570        parameter_ids: &[usize],
571        order: usize,
572    ) -> QuantRS2Result<Vec<Complex64>>
573    where
574        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
575    {
576        // Simplified implementation using repeated finite differences
577        if order == 1 {
578            let params = self.get_parameter_values(parameter_ids)?;
579            return self.compute_central_difference_gradients(function, &params, parameter_ids);
580        }
581
582        // For higher orders, use recursive finite differences
583        let mut derivatives = vec![Complex64::new(0.0, 0.0); parameter_ids.len()];
584        let h = self.config.finite_diff_step.powf(1.0 / order as f64);
585
586        for (i, _) in parameter_ids.iter().enumerate() {
587            // Simplified higher-order derivative using multiple function evaluations
588            derivatives[i] =
589                self.compute_higher_order_single_param(function, parameter_ids, i, order, h)?;
590        }
591
592        Ok(derivatives)
593    }
594
595    fn compute_higher_order_single_param<F>(
596        &self,
597        function: F,
598        parameter_ids: &[usize],
599        param_index: usize,
600        order: usize,
601        h: f64,
602    ) -> QuantRS2Result<Complex64>
603    where
604        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
605    {
606        let params = self.get_parameter_values(parameter_ids)?;
607
608        // Use finite difference approximation for higher-order derivatives
609        match order {
610            2 => {
611                // Second derivative: f''(x) ≈ (f(x+h) - 2f(x) + f(x-h)) / h²
612                let mut params_plus = params.clone();
613                let mut params_minus = params.clone();
614                params_plus[param_index] += h;
615                params_minus[param_index] -= h;
616
617                let f_plus = function(&params_plus)?;
618                let f_center = function(&params)?;
619                let f_minus = function(&params_minus)?;
620
621                Ok((f_plus - 2.0 * f_center + f_minus) / Complex64::new(h * h, 0.0))
622            }
623            3 => {
624                // Third derivative approximation
625                let mut params_2h = params.clone();
626                let mut params_h = params.clone();
627                let mut params_neg_h = params.clone();
628                let mut params_neg_2h = params;
629
630                params_2h[param_index] += 2.0 * h;
631                params_h[param_index] += h;
632                params_neg_h[param_index] -= h;
633                params_neg_2h[param_index] -= 2.0 * h;
634
635                let f_2h = function(&params_2h)?;
636                let f_h = function(&params_h)?;
637                let f_neg_h = function(&params_neg_h)?;
638                let f_neg_2h = function(&params_neg_2h)?;
639
640                Ok((f_2h - 2.0 * f_h + 2.0 * f_neg_h - f_neg_2h)
641                    / Complex64::new(2.0 * h * h * h, 0.0))
642            }
643            _ => {
644                // For other orders, use a simplified approximation
645                Ok(Complex64::new(0.0, 0.0))
646            }
647        }
648    }
649
650    fn compute_mixed_partial<F>(
651        &self,
652        function: F,
653        parameter_ids: &[usize],
654        i: usize,
655        j: usize,
656    ) -> QuantRS2Result<Complex64>
657    where
658        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
659    {
660        let params = self.get_parameter_values(parameter_ids)?;
661        let h = self.config.finite_diff_step;
662
663        // Mixed partial derivative: ∂²f/∂xi∂xj ≈ (f(xi+h,xj+h) - f(xi+h,xj-h) - f(xi-h,xj+h) + f(xi-h,xj-h)) / (4h²)
664        let mut params_pp = params.clone();
665        let mut params_pm = params.clone();
666        let mut params_mp = params.clone();
667        let mut params_mm = params;
668
669        params_pp[i] += h;
670        params_pp[j] += h;
671        params_pm[i] += h;
672        params_pm[j] -= h;
673        params_mp[i] -= h;
674        params_mp[j] += h;
675        params_mm[i] -= h;
676        params_mm[j] -= h;
677
678        let f_pp = function(&params_pp)?;
679        let f_pm = function(&params_pm)?;
680        let f_mp = function(&params_mp)?;
681        let f_mm = function(&params_mm)?;
682
683        Ok((f_pp - f_pm - f_mp + f_mm) / Complex64::new(4.0 * h * h, 0.0))
684    }
685
686    // Helper methods
687
688    fn get_parameter_values(&self, parameter_ids: &[usize]) -> QuantRS2Result<Vec<f64>> {
689        let registry = self
690            .parameter_registry
691            .read()
692            .expect("Parameter registry lock poisoned during value retrieval");
693        let mut values = Vec::new();
694
695        for &id in parameter_ids {
696            let param = registry.parameters.get(&id).ok_or_else(|| {
697                QuantRS2Error::InvalidParameter(format!("Parameter {id} not found"))
698            })?;
699            values.push(param.value);
700        }
701
702        Ok(values)
703    }
704
705    fn select_optimal_method(&self, gate_name: &str) -> DifferentiationMethod {
706        // Select optimal differentiation method based on gate type
707        match gate_name {
708            "RX" | "RY" | "RZ" | "PhaseShift" | "U1" | "U2" | "U3" => {
709                DifferentiationMethod::ParameterShift
710            }
711            _ => DifferentiationMethod::CentralDifference,
712        }
713    }
714
715    fn estimate_gradient_error(
716        &self,
717        gradients: &[Complex64],
718        method: DifferentiationMethod,
719    ) -> f64 {
720        // Estimate numerical error based on method and gradient magnitudes
721        let max_gradient = gradients.iter().map(|g| g.norm()).fold(0.0, f64::max);
722
723        match method {
724            DifferentiationMethod::ParameterShift
725            | DifferentiationMethod::ComplexStep
726            | DifferentiationMethod::DualNumber => max_gradient * 1e-15, // Machine precision
727            DifferentiationMethod::FiniteDifference => max_gradient * self.config.finite_diff_step,
728            DifferentiationMethod::CentralDifference => {
729                max_gradient * self.config.finite_diff_step * self.config.finite_diff_step
730            }
731            DifferentiationMethod::Hybrid => max_gradient * 1e-12,
732        }
733    }
734
735    fn generate_cache_key(&self, parameter_ids: &[usize], method: DifferentiationMethod) -> String {
736        format!("{parameter_ids:?}_{method:?}")
737    }
738
739    fn get_cached_gradient(&self, key: &str) -> Option<CacheEntry> {
740        let cache = self
741            .gradient_cache
742            .read()
743            .expect("Gradient cache lock poisoned during cache retrieval");
744        cache.entries.get(key).cloned()
745    }
746
747    fn cache_gradient(
748        &self,
749        key: String,
750        gradients: &[Complex64],
751        cost: f64,
752        method: DifferentiationMethod,
753    ) {
754        let mut cache = self
755            .gradient_cache
756            .write()
757            .expect("Gradient cache lock poisoned during cache insertion");
758        cache.insert(key, gradients.to_vec(), cost, method);
759    }
760
761    // Optimizer implementations
762
763    fn sgd_update(&self, gradients: &GradientResult, learning_rate: f64) -> QuantRS2Result<()> {
764        let mut registry = self
765            .parameter_registry
766            .write()
767            .expect("Parameter registry lock poisoned during SGD update");
768
769        for (i, &param_id) in gradients.parameter_ids.iter().enumerate() {
770            if let Some(param) = registry.parameters.get_mut(&param_id) {
771                let gradient_real = gradients.gradients[i].re;
772                param.value -= learning_rate * gradient_real;
773
774                // Apply bounds if specified
775                if let Some((min_val, max_val)) = param.bounds {
776                    param.value = param.value.clamp(min_val, max_val);
777                }
778            }
779        }
780
781        Ok(())
782    }
783
784    const fn adam_update(
785        &self,
786        _gradients: &GradientResult,
787        _learning_rate: f64,
788    ) -> QuantRS2Result<()> {
789        // Simplified Adam optimizer implementation
790        // In a full implementation, this would track momentum and second moments
791        Ok(())
792    }
793
794    const fn lbfgs_update(
795        &self,
796        _gradients: &GradientResult,
797        _learning_rate: f64,
798    ) -> QuantRS2Result<()> {
799        // Simplified L-BFGS implementation
800        Ok(())
801    }
802
803    const fn adagrad_update(
804        &self,
805        _gradients: &GradientResult,
806        _learning_rate: f64,
807    ) -> QuantRS2Result<()> {
808        // Simplified AdaGrad implementation
809        Ok(())
810    }
811}
812
813#[derive(Debug, Clone, Copy)]
814pub enum OptimizerType {
815    SGD,
816    Adam,
817    LBFGS,
818    AdaGrad,
819}
820
821impl GradientCache {
822    fn new() -> Self {
823        Self {
824            entries: HashMap::new(),
825            access_order: Vec::new(),
826            total_size: 0,
827        }
828    }
829
830    fn insert(
831        &mut self,
832        key: String,
833        gradients: Vec<Complex64>,
834        cost: f64,
835        method: DifferentiationMethod,
836    ) {
837        let entry = CacheEntry {
838            gradients,
839            computation_cost: cost,
840            timestamp: std::time::Instant::now(),
841            method_used: method,
842        };
843
844        self.entries.insert(key.clone(), entry);
845        self.access_order.push(key);
846        self.total_size += 1;
847
848        // Simple LRU eviction if cache is too large
849        while self.total_size > 1000 {
850            if let Some(oldest_key) = self.access_order.first().cloned() {
851                self.entries.remove(&oldest_key);
852                self.access_order.remove(0);
853                self.total_size -= 1;
854            } else {
855                break;
856            }
857        }
858    }
859}
860
861impl ComputationGraph {
862    fn new() -> Self {
863        Self {
864            nodes: Vec::new(),
865            edges: Vec::new(),
866            parameter_dependencies: HashMap::new(),
867        }
868    }
869}
870
871impl ParameterRegistry {
872    fn new() -> Self {
873        Self {
874            parameters: HashMap::new(),
875            next_id: 0,
876        }
877    }
878
879    fn add_parameter(&mut self, name: &str, value: f64, bounds: Option<(f64, f64)>) -> usize {
880        let id = self.next_id;
881        self.next_id += 1;
882
883        let parameter = Parameter {
884            id,
885            name: name.to_string(),
886            value,
887            bounds,
888            differentiable: true,
889            gradient_method: None,
890        };
891
892        self.parameters.insert(id, parameter);
893        id
894    }
895}
896
897/// Factory for creating quantum autodiff engines with different configurations
898pub struct QuantumAutoDiffFactory;
899
900impl QuantumAutoDiffFactory {
901    /// Create a high-precision autodiff engine
902    pub fn create_high_precision() -> QuantumAutoDiff {
903        let config = QuantumAutoDiffConfig {
904            finite_diff_step: 1e-10,
905            gradient_precision: 1e-15,
906            max_derivative_order: 5,
907            default_method: DifferentiationMethod::ParameterShift,
908            ..Default::default()
909        };
910        QuantumAutoDiff::new(config)
911    }
912
913    /// Create a performance-optimized autodiff engine
914    pub fn create_performance_optimized() -> QuantumAutoDiff {
915        let config = QuantumAutoDiffConfig {
916            finite_diff_step: 1e-5,
917            enable_higher_order: false,
918            max_derivative_order: 1,
919            enable_caching: true,
920            cache_size_limit: 50000,
921            default_method: DifferentiationMethod::Hybrid,
922            ..Default::default()
923        };
924        QuantumAutoDiff::new(config)
925    }
926
927    /// Create an autodiff engine optimized for VQE
928    pub fn create_for_vqe() -> QuantumAutoDiff {
929        let config = QuantumAutoDiffConfig {
930            default_method: DifferentiationMethod::ParameterShift,
931            parameter_shift_step: std::f64::consts::PI / 2.0,
932            enable_higher_order: false,
933            enable_caching: true,
934            ..Default::default()
935        };
936        QuantumAutoDiff::new(config)
937    }
938
939    /// Create an autodiff engine optimized for QAOA
940    pub fn create_for_qaoa() -> QuantumAutoDiff {
941        let config = QuantumAutoDiffConfig {
942            default_method: DifferentiationMethod::CentralDifference,
943            finite_diff_step: 1e-6,
944            enable_higher_order: true,
945            max_derivative_order: 2,
946            ..Default::default()
947        };
948        QuantumAutoDiff::new(config)
949    }
950}
951
#[cfg(test)]
mod tests {
    //! Unit tests for the quantum autodiff engine, its parameter registry and factory presets.
    //!
    //! Fallible calls are unwrapped with `expect`/`unwrap_or_else(panic!)` rather than
    //! `assert!(x.is_ok())`, so a failing test reports the underlying error.

    use super::*;

    #[test]
    fn test_quantum_autodiff_creation() {
        let config = QuantumAutoDiffConfig::default();
        let autodiff = QuantumAutoDiff::new(config);

        assert_eq!(
            autodiff.config.default_method,
            DifferentiationMethod::ParameterShift
        );
    }

    #[test]
    fn test_parameter_registration() {
        let mut autodiff = QuantumAutoDiff::new(QuantumAutoDiffConfig::default());

        let id = autodiff
            .register_parameter("theta", 0.5, Some((0.0, 2.0 * std::f64::consts::PI)))
            .expect("Failed to register parameter");
        let values = autodiff
            .get_parameter_values(&[id])
            .expect("Failed to get parameter values");
        assert_eq!(values[0], 0.5);
    }

    #[test]
    fn test_gradient_computation() {
        let mut autodiff = QuantumAutoDiff::new(QuantumAutoDiffConfig::default());

        let param_id = autodiff
            .register_parameter("x", 1.0, None)
            .expect("Failed to register parameter");

        // Simple quadratic function: f(x) = x^2
        let function = |params: &[f64]| -> QuantRS2Result<Complex64> {
            Ok(Complex64::new(params[0] * params[0], 0.0))
        };

        let result = autodiff
            .compute_gradients(
                function,
                &[param_id],
                Some(DifferentiationMethod::CentralDifference),
            )
            .expect("Failed to compute gradients");

        // The gradient of x^2 at x = 1 is 2. The tolerance is deliberately loose: this
        // test only guards against gross errors in the central-difference path, not
        // against step-size-dependent rounding.
        assert!(
            (result.gradients[0].re - 2.0).abs() < 1.0,
            "Expected gradient close to 2.0, got {}",
            result.gradients[0].re
        );
    }

    #[test]
    fn test_different_differentiation_methods() {
        let mut autodiff = QuantumAutoDiff::new(QuantumAutoDiffConfig::default());
        let param_id = autodiff
            .register_parameter("x", 0.5, None)
            .expect("Failed to register parameter");

        // f(x) = sin(x)
        let function = |params: &[f64]| -> QuantRS2Result<Complex64> {
            Ok(Complex64::new(params[0].sin(), 0.0))
        };

        // The derivative of sin(x) at x = 0.5 is cos(0.5), whichever method is used.
        let expected = 0.5_f64.cos();

        for method in [
            DifferentiationMethod::ParameterShift,
            DifferentiationMethod::FiniteDifference,
            DifferentiationMethod::CentralDifference,
        ] {
            // Report which method failed and why, instead of a bare `is_ok()` failure.
            let computed = autodiff
                .compute_gradients(function, &[param_id], Some(method))
                .unwrap_or_else(|e| panic!("Method {:?} failed: {:?}", method, e))
                .gradients[0]
                .re;
            assert!(
                (computed - expected).abs() < 0.1,
                "Method {:?}: expected {}, got {}",
                method,
                expected,
                computed
            );
        }
    }

    #[test]
    fn test_higher_order_derivatives() {
        let mut autodiff = QuantumAutoDiff::new(QuantumAutoDiffConfig::default());
        let param_id = autodiff
            .register_parameter("x", 1.0, None)
            .expect("Failed to register parameter");

        // f(x) = x^3
        let function = |params: &[f64]| -> QuantRS2Result<Complex64> {
            Ok(Complex64::new(params[0].powi(3), 0.0))
        };

        let derivatives = autodiff
            .compute_higher_order_derivatives(function, &[param_id], 3)
            .expect("Failed to compute higher order derivatives");
        // One entry per requested derivative order (1st, 2nd, 3rd).
        assert_eq!(derivatives.derivatives.len(), 3);
    }

    #[test]
    fn test_circuit_gradients() {
        let mut autodiff = QuantumAutoDiff::new(QuantumAutoDiffConfig::default());

        let theta_id = autodiff
            .register_parameter("theta", 0.0, None)
            .expect("Failed to register theta parameter");
        let phi_id = autodiff
            .register_parameter("phi", 0.0, None)
            .expect("Failed to register phi parameter");

        let circuit_function = |params: &[f64], _observable: &str| -> QuantRS2Result<Complex64> {
            // Simple parameterized circuit expectation value
            let theta = if !params.is_empty() { params[0] } else { 0.0 };
            let phi = if params.len() > 1 { params[1] } else { 0.0 };
            let result = (theta.cos() * phi.sin()).abs();
            Ok(Complex64::new(result, 0.0))
        };

        let gate_parameters = vec![
            (0, "RY".to_string(), vec![theta_id]),
            (1, "RZ".to_string(), vec![phi_id]),
        ];

        let gradients = autodiff
            .circuit_gradients(circuit_function, &gate_parameters, "Z")
            .expect("Failed to compute circuit gradients");
        // One gradient result per parameterized gate.
        assert_eq!(gradients.len(), 2);
    }

    #[test]
    fn test_parameter_update() {
        let mut autodiff = QuantumAutoDiff::new(QuantumAutoDiffConfig::default());
        let param_id = autodiff
            .register_parameter("x", 1.0, None)
            .expect("Failed to register parameter");

        let gradient_result = GradientResult {
            gradients: vec![Complex64::new(2.0, 0.0)],
            parameter_ids: vec![param_id],
            computation_method: DifferentiationMethod::ParameterShift,
            computation_time: std::time::Duration::from_millis(10),
            numerical_error_estimate: 1e-15,
        };

        let learning_rate = 0.1;
        autodiff
            .parameter_update(&gradient_result, learning_rate, OptimizerType::SGD)
            .expect("Failed to update parameters");

        // Parameter should be updated: x_new = x_old - lr * gradient = 1.0 - 0.1 * 2.0 = 0.8
        let new_values = autodiff
            .get_parameter_values(&[param_id])
            .expect("Failed to get updated parameter values");
        assert!((new_values[0] - 0.8).abs() < 1e-10);
    }

    #[test]
    fn test_gradient_caching() {
        let mut autodiff = QuantumAutoDiff::new(QuantumAutoDiffConfig::default());
        let param_id = autodiff
            .register_parameter("x", 1.0, None)
            .expect("Failed to register parameter");

        let function = |params: &[f64]| -> QuantRS2Result<Complex64> {
            Ok(Complex64::new(params[0] * params[0], 0.0))
        };

        // First computation. Caching is enabled in the default config, so the repeat
        // call below may be served from the cache.
        let start = std::time::Instant::now();
        let result1 = autodiff
            .compute_gradients(function, &[param_id], None)
            .expect("Failed to compute first gradient");
        let time1 = start.elapsed();

        // Second computation with identical inputs.
        let start = std::time::Instant::now();
        let result2 = autodiff
            .compute_gradients(function, &[param_id], None)
            .expect("Failed to compute second gradient");
        let time2 = start.elapsed();

        // Results must be identical whether or not the cache was hit.
        assert!((result1.gradients[0] - result2.gradients[0]).norm() < 1e-15);

        // Timings are printed for manual inspection only; they are intentionally not
        // asserted because wall-clock comparisons are unreliable under test runners.
        println!("First: {:?}, Second: {:?}", time1, time2);
    }

    #[test]
    fn test_factory_methods() {
        let high_precision = QuantumAutoDiffFactory::create_high_precision();
        let performance = QuantumAutoDiffFactory::create_performance_optimized();
        let vqe = QuantumAutoDiffFactory::create_for_vqe();
        let qaoa = QuantumAutoDiffFactory::create_for_qaoa();

        assert_eq!(high_precision.config.finite_diff_step, 1e-10);
        assert_eq!(performance.config.max_derivative_order, 1);
        assert_eq!(
            vqe.config.default_method,
            DifferentiationMethod::ParameterShift
        );
        assert_eq!(
            qaoa.config.default_method,
            DifferentiationMethod::CentralDifference
        );
    }

    #[test]
    fn test_error_estimation() {
        let autodiff = QuantumAutoDiff::new(QuantumAutoDiffConfig::default());

        let gradients = vec![Complex64::new(1.0, 0.0), Complex64::new(0.5, 0.0)];

        let error_ps =
            autodiff.estimate_gradient_error(&gradients, DifferentiationMethod::ParameterShift);
        let error_fd =
            autodiff.estimate_gradient_error(&gradients, DifferentiationMethod::FiniteDifference);
        let error_cd =
            autodiff.estimate_gradient_error(&gradients, DifferentiationMethod::CentralDifference);

        // Both parameter-shift and central-difference estimates must be strictly below
        // the forward finite-difference estimate.
        assert!(error_ps < error_fd);
        assert!(error_cd < error_fd);
    }
}