quantrs2_core/
quantum_autodiff.rs

1//! Enhanced Automatic Differentiation for Quantum Gradients
2//!
3//! This module provides advanced automatic differentiation capabilities
4//! specifically designed for quantum computing, including parameter-shift
5//! rules, finite differences, and hybrid classical-quantum gradients.
6
7use crate::error::{QuantRS2Error, QuantRS2Result};
8use scirs2_core::Complex64;
9use std::{
10    collections::HashMap,
11    fmt,
12    sync::{Arc, RwLock},
13};
14
15/// Configuration for quantum automatic differentiation
16#[derive(Debug, Clone)]
17pub struct QuantumAutoDiffConfig {
18    /// Default differentiation method
19    pub default_method: DifferentiationMethod,
20    /// Finite difference step size
21    pub finite_diff_step: f64,
22    /// Parameter-shift rule step size
23    pub parameter_shift_step: f64,
24    /// Enable higher-order derivatives
25    pub enable_higher_order: bool,
26    /// Maximum order of derivatives to compute
27    pub max_derivative_order: usize,
28    /// Gradient computation precision
29    pub gradient_precision: f64,
30    /// Enable gradient caching
31    pub enable_caching: bool,
32    /// Cache size limit
33    pub cache_size_limit: usize,
34}
35
36impl Default for QuantumAutoDiffConfig {
37    fn default() -> Self {
38        Self {
39            default_method: DifferentiationMethod::ParameterShift,
40            finite_diff_step: 1e-7,
41            parameter_shift_step: std::f64::consts::PI / 2.0,
42            enable_higher_order: true,
43            max_derivative_order: 3,
44            gradient_precision: 1e-12,
45            enable_caching: true,
46            cache_size_limit: 10000,
47        }
48    }
49}
50
51/// Methods for computing quantum gradients
52#[derive(Debug, Clone, Copy, PartialEq, Eq)]
53pub enum DifferentiationMethod {
54    /// Parameter-shift rule (exact for many quantum gates)
55    ParameterShift,
56    /// Finite differences (numerical approximation)
57    FiniteDifference,
58    /// Central differences (more accurate numerical)
59    CentralDifference,
60    /// Complex step differentiation
61    ComplexStep,
62    /// Automatic differentiation using dual numbers
63    DualNumber,
64    /// Hybrid method (automatic selection)
65    Hybrid,
66}
67
68impl fmt::Display for DifferentiationMethod {
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        match self {
71            DifferentiationMethod::ParameterShift => write!(f, "Parameter-Shift"),
72            DifferentiationMethod::FiniteDifference => write!(f, "Finite Difference"),
73            DifferentiationMethod::CentralDifference => write!(f, "Central Difference"),
74            DifferentiationMethod::ComplexStep => write!(f, "Complex Step"),
75            DifferentiationMethod::DualNumber => write!(f, "Dual Number"),
76            DifferentiationMethod::Hybrid => write!(f, "Hybrid"),
77        }
78    }
79}
80
81/// Quantum automatic differentiation engine
82#[derive(Debug)]
83pub struct QuantumAutoDiff {
84    config: QuantumAutoDiffConfig,
85    gradient_cache: Arc<RwLock<GradientCache>>,
86    computation_graph: Arc<RwLock<ComputationGraph>>,
87    parameter_registry: Arc<RwLock<ParameterRegistry>>,
88}
89
90/// Cache for computed gradients
91#[derive(Debug)]
92pub struct GradientCache {
93    entries: HashMap<String, CacheEntry>,
94    access_order: Vec<String>,
95    total_size: usize,
96}
97
98#[derive(Debug, Clone)]
99pub struct CacheEntry {
100    gradients: Vec<Complex64>,
101    computation_cost: f64,
102    timestamp: std::time::Instant,
103    method_used: DifferentiationMethod,
104}
105
106/// Computation graph for tracking quantum operations
107#[derive(Debug)]
108pub struct ComputationGraph {
109    nodes: Vec<ComputationNode>,
110    edges: Vec<ComputationEdge>,
111    parameter_dependencies: HashMap<usize, Vec<usize>>,
112}
113
114#[derive(Debug, Clone)]
115pub struct ComputationNode {
116    id: usize,
117    operation: QuantumOperation,
118    inputs: Vec<usize>,
119    outputs: Vec<usize>,
120}
121
122#[derive(Debug, Clone)]
123pub struct ComputationEdge {
124    from: usize,
125    to: usize,
126    parameter_id: Option<usize>,
127}
128
129#[derive(Debug, Clone)]
130pub enum QuantumOperation {
131    Gate { name: String, parameters: Vec<f64> },
132    Measurement { observable: String },
133    StatePreparation { amplitudes: Vec<Complex64> },
134    Expectation { observable: String },
135}
136
137/// Registry for tracking parameters
138#[derive(Debug)]
139pub struct ParameterRegistry {
140    parameters: HashMap<usize, Parameter>,
141    next_id: usize,
142}
143
144#[derive(Debug, Clone)]
145pub struct Parameter {
146    id: usize,
147    name: String,
148    value: f64,
149    bounds: Option<(f64, f64)>,
150    differentiable: bool,
151    gradient_method: Option<DifferentiationMethod>,
152}
153
154/// Result of gradient computation
155#[derive(Debug, Clone)]
156pub struct GradientResult {
157    pub gradients: Vec<Complex64>,
158    pub parameter_ids: Vec<usize>,
159    pub computation_method: DifferentiationMethod,
160    pub computation_time: std::time::Duration,
161    pub numerical_error_estimate: f64,
162}
163
164/// Higher-order derivative result
165#[derive(Debug, Clone)]
166pub struct HigherOrderResult {
167    pub derivatives: Vec<Vec<Complex64>>, // derivatives[order][parameter]
168    pub parameter_ids: Vec<usize>,
169    pub orders: Vec<usize>,
170    pub mixed_derivatives: HashMap<(usize, usize), Complex64>,
171}
172
173impl QuantumAutoDiff {
174    /// Create a new quantum automatic differentiation engine
175    pub fn new(config: QuantumAutoDiffConfig) -> Self {
176        Self {
177            config,
178            gradient_cache: Arc::new(RwLock::new(GradientCache::new())),
179            computation_graph: Arc::new(RwLock::new(ComputationGraph::new())),
180            parameter_registry: Arc::new(RwLock::new(ParameterRegistry::new())),
181        }
182    }
183
184    /// Register a parameter for differentiation
185    pub fn register_parameter(
186        &mut self,
187        name: &str,
188        initial_value: f64,
189        bounds: Option<(f64, f64)>,
190    ) -> QuantRS2Result<usize> {
191        let mut registry = self.parameter_registry.write().unwrap();
192        Ok(registry.add_parameter(name, initial_value, bounds))
193    }
194
195    /// Compute gradients using the specified method
196    pub fn compute_gradients<F>(
197        &mut self,
198        function: F,
199        parameter_ids: &[usize],
200        method: Option<DifferentiationMethod>,
201    ) -> QuantRS2Result<GradientResult>
202    where
203        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
204    {
205        let method = method.unwrap_or(self.config.default_method);
206        let start_time = std::time::Instant::now();
207
208        // Check cache first
209        if self.config.enable_caching {
210            let cache_key = self.generate_cache_key(parameter_ids, method);
211            if let Some(cached) = self.get_cached_gradient(&cache_key) {
212                return Ok(GradientResult {
213                    gradients: cached.gradients,
214                    parameter_ids: parameter_ids.to_vec(),
215                    computation_method: cached.method_used,
216                    computation_time: start_time.elapsed(),
217                    numerical_error_estimate: 0.0, // Cached result
218                });
219            }
220        }
221
222        // Get current parameter values
223        let parameter_values = self.get_parameter_values(parameter_ids)?;
224
225        let gradients = match method {
226            DifferentiationMethod::ParameterShift => {
227                self.compute_parameter_shift_gradients(function, &parameter_values, parameter_ids)?
228            }
229            DifferentiationMethod::FiniteDifference => self.compute_finite_difference_gradients(
230                function,
231                &parameter_values,
232                parameter_ids,
233            )?,
234            DifferentiationMethod::CentralDifference => self.compute_central_difference_gradients(
235                function,
236                &parameter_values,
237                parameter_ids,
238            )?,
239            DifferentiationMethod::ComplexStep => {
240                self.compute_complex_step_gradients(function, &parameter_values, parameter_ids)?
241            }
242            DifferentiationMethod::DualNumber => {
243                self.compute_dual_number_gradients(function, &parameter_values, parameter_ids)?
244            }
245            DifferentiationMethod::Hybrid => {
246                self.compute_hybrid_gradients(function, &parameter_values, parameter_ids)?
247            }
248        };
249
250        let computation_time = start_time.elapsed();
251        let error_estimate = self.estimate_gradient_error(&gradients, method);
252
253        let result = GradientResult {
254            gradients: gradients.clone(),
255            parameter_ids: parameter_ids.to_vec(),
256            computation_method: method,
257            computation_time,
258            numerical_error_estimate: error_estimate,
259        };
260
261        // Cache the result
262        if self.config.enable_caching {
263            let cache_key = self.generate_cache_key(parameter_ids, method);
264            self.cache_gradient(
265                cache_key,
266                &gradients,
267                computation_time.as_secs_f64(),
268                method,
269            );
270        }
271
272        Ok(result)
273    }
274
275    /// Compute higher-order derivatives
276    pub fn compute_higher_order_derivatives<F>(
277        &mut self,
278        function: F,
279        parameter_ids: &[usize],
280        max_order: usize,
281    ) -> QuantRS2Result<HigherOrderResult>
282    where
283        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
284    {
285        if !self.config.enable_higher_order {
286            return Err(QuantRS2Error::UnsupportedOperation(
287                "Higher-order derivatives disabled".to_string(),
288            ));
289        }
290
291        let max_order = max_order.min(self.config.max_derivative_order);
292        let mut derivatives = Vec::new();
293        let mut mixed_derivatives = HashMap::new();
294
295        // Compute derivatives of each order
296        for order in 1..=max_order {
297            let order_derivatives =
298                self.compute_nth_order_derivatives(function, parameter_ids, order)?;
299            derivatives.push(order_derivatives);
300        }
301
302        // Compute mixed partial derivatives for second order
303        if max_order >= 2 && parameter_ids.len() >= 2 {
304            for i in 0..parameter_ids.len() {
305                for j in (i + 1)..parameter_ids.len() {
306                    let mixed = self.compute_mixed_partial(function, parameter_ids, i, j)?;
307                    mixed_derivatives.insert((i, j), mixed);
308                }
309            }
310        }
311
312        Ok(HigherOrderResult {
313            derivatives,
314            parameter_ids: parameter_ids.to_vec(),
315            orders: (1..=max_order).collect(),
316            mixed_derivatives,
317        })
318    }
319
320    /// Compute gradients with respect to quantum circuit parameters
321    pub fn circuit_gradients<F>(
322        &mut self,
323        circuit_function: F,
324        gate_parameters: &[(usize, String, Vec<usize>)], // (gate_id, gate_name, param_indices)
325        observable: &str,
326    ) -> QuantRS2Result<Vec<GradientResult>>
327    where
328        F: Fn(&[f64], &str) -> QuantRS2Result<Complex64> + Copy,
329    {
330        let mut results = Vec::new();
331
332        for (_gate_id, gate_name, param_indices) in gate_parameters {
333            // Determine best differentiation method for this gate
334            let method = self.select_optimal_method(gate_name);
335
336            let gate_function = |params: &[f64]| -> QuantRS2Result<Complex64> {
337                circuit_function(params, observable)
338            };
339
340            let gradient = self.compute_gradients(gate_function, param_indices, Some(method))?;
341            results.push(gradient);
342        }
343
344        Ok(results)
345    }
346
347    /// Optimize parameter update using gradient information
348    pub fn parameter_update(
349        &mut self,
350        gradients: &GradientResult,
351        learning_rate: f64,
352        optimizer: OptimizerType,
353    ) -> QuantRS2Result<()> {
354        match optimizer {
355            OptimizerType::SGD => {
356                self.sgd_update(gradients, learning_rate)?;
357            }
358            OptimizerType::Adam => {
359                self.adam_update(gradients, learning_rate)?;
360            }
361            OptimizerType::LBFGS => {
362                self.lbfgs_update(gradients, learning_rate)?;
363            }
364            OptimizerType::AdaGrad => {
365                self.adagrad_update(gradients, learning_rate)?;
366            }
367        }
368        Ok(())
369    }
370
371    // Private methods for different differentiation approaches
372
373    fn compute_parameter_shift_gradients<F>(
374        &self,
375        function: F,
376        parameters: &[f64],
377        _parameter_ids: &[usize],
378    ) -> QuantRS2Result<Vec<Complex64>>
379    where
380        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
381    {
382        let mut gradients = Vec::new();
383        let shift = self.config.parameter_shift_step;
384
385        for i in 0..parameters.len() {
386            let mut params_plus = parameters.to_vec();
387            let mut params_minus = parameters.to_vec();
388
389            params_plus[i] += shift;
390            params_minus[i] -= shift;
391
392            let f_plus = function(&params_plus)?;
393            let f_minus = function(&params_minus)?;
394
395            // Parameter-shift rule: gradient = (f(θ + π/2) - f(θ - π/2)) / 2
396            let gradient = (f_plus - f_minus) / Complex64::new(2.0, 0.0);
397            gradients.push(gradient);
398        }
399
400        Ok(gradients)
401    }
402
403    fn compute_finite_difference_gradients<F>(
404        &self,
405        function: F,
406        parameters: &[f64],
407        _parameter_ids: &[usize],
408    ) -> QuantRS2Result<Vec<Complex64>>
409    where
410        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
411    {
412        let mut gradients = Vec::new();
413        let h = self.config.finite_diff_step;
414        let f_original = function(parameters)?;
415
416        for i in 0..parameters.len() {
417            let mut params_h = parameters.to_vec();
418            params_h[i] += h;
419
420            let f_h = function(&params_h)?;
421            let gradient = (f_h - f_original) / Complex64::new(h, 0.0);
422            gradients.push(gradient);
423        }
424
425        Ok(gradients)
426    }
427
428    fn compute_central_difference_gradients<F>(
429        &self,
430        function: F,
431        parameters: &[f64],
432        _parameter_ids: &[usize],
433    ) -> QuantRS2Result<Vec<Complex64>>
434    where
435        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
436    {
437        let mut gradients = Vec::new();
438        let h = self.config.finite_diff_step;
439
440        for i in 0..parameters.len() {
441            let mut params_plus = parameters.to_vec();
442            let mut params_minus = parameters.to_vec();
443
444            params_plus[i] += h;
445            params_minus[i] -= h;
446
447            let f_plus = function(&params_plus)?;
448            let f_minus = function(&params_minus)?;
449
450            // Central difference: gradient = (f(θ + h) - f(θ - h)) / (2h)
451            let gradient = (f_plus - f_minus) / Complex64::new(2.0 * h, 0.0);
452            gradients.push(gradient);
453        }
454
455        Ok(gradients)
456    }
457
458    fn compute_complex_step_gradients<F>(
459        &self,
460        function: F,
461        parameters: &[f64],
462        _parameter_ids: &[usize],
463    ) -> QuantRS2Result<Vec<Complex64>>
464    where
465        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
466    {
467        // Complex step differentiation is not directly applicable to real-valued parameters
468        // This is a simplified implementation
469        self.compute_central_difference_gradients(function, parameters, _parameter_ids)
470    }
471
472    fn compute_dual_number_gradients<F>(
473        &self,
474        function: F,
475        parameters: &[f64],
476        _parameter_ids: &[usize],
477    ) -> QuantRS2Result<Vec<Complex64>>
478    where
479        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
480    {
481        // Simplified dual number implementation using finite differences
482        self.compute_central_difference_gradients(function, parameters, _parameter_ids)
483    }
484
485    fn compute_hybrid_gradients<F>(
486        &self,
487        function: F,
488        parameters: &[f64],
489        parameter_ids: &[usize],
490    ) -> QuantRS2Result<Vec<Complex64>>
491    where
492        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
493    {
494        // Use parameter-shift for most parameters, finite difference for others
495        let registry = self.parameter_registry.read().unwrap();
496        let mut gradients = Vec::new();
497
498        for (i, &param_id) in parameter_ids.iter().enumerate() {
499            let param = registry.parameters.get(&param_id);
500            let method = param
501                .and_then(|p| p.gradient_method)
502                .unwrap_or(DifferentiationMethod::ParameterShift);
503
504            let single_param_gradient = match method {
505                DifferentiationMethod::ParameterShift => {
506                    self.compute_single_parameter_shift_gradient(function, parameters, i)?
507                }
508                _ => self.compute_single_finite_difference_gradient(function, parameters, i)?,
509            };
510
511            gradients.push(single_param_gradient);
512        }
513
514        Ok(gradients)
515    }
516
517    fn compute_single_parameter_shift_gradient<F>(
518        &self,
519        function: F,
520        parameters: &[f64],
521        param_index: usize,
522    ) -> QuantRS2Result<Complex64>
523    where
524        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
525    {
526        let shift = self.config.parameter_shift_step;
527        let mut params_plus = parameters.to_vec();
528        let mut params_minus = parameters.to_vec();
529
530        params_plus[param_index] += shift;
531        params_minus[param_index] -= shift;
532
533        let f_plus = function(&params_plus)?;
534        let f_minus = function(&params_minus)?;
535
536        Ok((f_plus - f_minus) / Complex64::new(2.0, 0.0))
537    }
538
539    fn compute_single_finite_difference_gradient<F>(
540        &self,
541        function: F,
542        parameters: &[f64],
543        param_index: usize,
544    ) -> QuantRS2Result<Complex64>
545    where
546        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
547    {
548        let h = self.config.finite_diff_step;
549        let mut params_plus = parameters.to_vec();
550        let mut params_minus = parameters.to_vec();
551
552        params_plus[param_index] += h;
553        params_minus[param_index] -= h;
554
555        let f_plus = function(&params_plus)?;
556        let f_minus = function(&params_minus)?;
557
558        Ok((f_plus - f_minus) / Complex64::new(2.0 * h, 0.0))
559    }
560
561    fn compute_nth_order_derivatives<F>(
562        &self,
563        function: F,
564        parameter_ids: &[usize],
565        order: usize,
566    ) -> QuantRS2Result<Vec<Complex64>>
567    where
568        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
569    {
570        // Simplified implementation using repeated finite differences
571        if order == 1 {
572            let params = self.get_parameter_values(parameter_ids)?;
573            return self.compute_central_difference_gradients(function, &params, parameter_ids);
574        }
575
576        // For higher orders, use recursive finite differences
577        let mut derivatives = vec![Complex64::new(0.0, 0.0); parameter_ids.len()];
578        let h = self.config.finite_diff_step.powf(1.0 / order as f64);
579
580        for (i, _) in parameter_ids.iter().enumerate() {
581            // Simplified higher-order derivative using multiple function evaluations
582            derivatives[i] =
583                self.compute_higher_order_single_param(function, parameter_ids, i, order, h)?;
584        }
585
586        Ok(derivatives)
587    }
588
589    fn compute_higher_order_single_param<F>(
590        &self,
591        function: F,
592        parameter_ids: &[usize],
593        param_index: usize,
594        order: usize,
595        h: f64,
596    ) -> QuantRS2Result<Complex64>
597    where
598        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
599    {
600        let params = self.get_parameter_values(parameter_ids)?;
601
602        // Use finite difference approximation for higher-order derivatives
603        match order {
604            2 => {
605                // Second derivative: f''(x) ≈ (f(x+h) - 2f(x) + f(x-h)) / h²
606                let mut params_plus = params.clone();
607                let mut params_minus = params.clone();
608                params_plus[param_index] += h;
609                params_minus[param_index] -= h;
610
611                let f_plus = function(&params_plus)?;
612                let f_center = function(&params)?;
613                let f_minus = function(&params_minus)?;
614
615                Ok((f_plus - 2.0 * f_center + f_minus) / Complex64::new(h * h, 0.0))
616            }
617            3 => {
618                // Third derivative approximation
619                let mut params_2h = params.clone();
620                let mut params_h = params.clone();
621                let mut params_neg_h = params.clone();
622                let mut params_neg_2h = params.clone();
623
624                params_2h[param_index] += 2.0 * h;
625                params_h[param_index] += h;
626                params_neg_h[param_index] -= h;
627                params_neg_2h[param_index] -= 2.0 * h;
628
629                let f_2h = function(&params_2h)?;
630                let f_h = function(&params_h)?;
631                let f_neg_h = function(&params_neg_h)?;
632                let f_neg_2h = function(&params_neg_2h)?;
633
634                Ok((f_2h - 2.0 * f_h + 2.0 * f_neg_h - f_neg_2h)
635                    / Complex64::new(2.0 * h * h * h, 0.0))
636            }
637            _ => {
638                // For other orders, use a simplified approximation
639                Ok(Complex64::new(0.0, 0.0))
640            }
641        }
642    }
643
644    fn compute_mixed_partial<F>(
645        &self,
646        function: F,
647        parameter_ids: &[usize],
648        i: usize,
649        j: usize,
650    ) -> QuantRS2Result<Complex64>
651    where
652        F: Fn(&[f64]) -> QuantRS2Result<Complex64> + Copy,
653    {
654        let params = self.get_parameter_values(parameter_ids)?;
655        let h = self.config.finite_diff_step;
656
657        // Mixed partial derivative: ∂²f/∂xi∂xj ≈ (f(xi+h,xj+h) - f(xi+h,xj-h) - f(xi-h,xj+h) + f(xi-h,xj-h)) / (4h²)
658        let mut params_pp = params.clone();
659        let mut params_pm = params.clone();
660        let mut params_mp = params.clone();
661        let mut params_mm = params.clone();
662
663        params_pp[i] += h;
664        params_pp[j] += h;
665        params_pm[i] += h;
666        params_pm[j] -= h;
667        params_mp[i] -= h;
668        params_mp[j] += h;
669        params_mm[i] -= h;
670        params_mm[j] -= h;
671
672        let f_pp = function(&params_pp)?;
673        let f_pm = function(&params_pm)?;
674        let f_mp = function(&params_mp)?;
675        let f_mm = function(&params_mm)?;
676
677        Ok((f_pp - f_pm - f_mp + f_mm) / Complex64::new(4.0 * h * h, 0.0))
678    }
679
680    // Helper methods
681
682    fn get_parameter_values(&self, parameter_ids: &[usize]) -> QuantRS2Result<Vec<f64>> {
683        let registry = self.parameter_registry.read().unwrap();
684        let mut values = Vec::new();
685
686        for &id in parameter_ids {
687            let param = registry.parameters.get(&id).ok_or_else(|| {
688                QuantRS2Error::InvalidParameter(format!("Parameter {} not found", id))
689            })?;
690            values.push(param.value);
691        }
692
693        Ok(values)
694    }
695
696    fn select_optimal_method(&self, gate_name: &str) -> DifferentiationMethod {
697        // Select optimal differentiation method based on gate type
698        match gate_name {
699            "RX" | "RY" | "RZ" | "PhaseShift" => DifferentiationMethod::ParameterShift,
700            "U1" | "U2" | "U3" => DifferentiationMethod::ParameterShift,
701            _ => DifferentiationMethod::CentralDifference,
702        }
703    }
704
705    fn estimate_gradient_error(
706        &self,
707        gradients: &[Complex64],
708        method: DifferentiationMethod,
709    ) -> f64 {
710        // Estimate numerical error based on method and gradient magnitudes
711        let max_gradient = gradients.iter().map(|g| g.norm()).fold(0.0, f64::max);
712
713        match method {
714            DifferentiationMethod::ParameterShift => max_gradient * 1e-15, // Machine precision
715            DifferentiationMethod::FiniteDifference => max_gradient * self.config.finite_diff_step,
716            DifferentiationMethod::CentralDifference => {
717                max_gradient * self.config.finite_diff_step * self.config.finite_diff_step
718            }
719            DifferentiationMethod::ComplexStep => max_gradient * 1e-15,
720            DifferentiationMethod::DualNumber => max_gradient * 1e-15,
721            DifferentiationMethod::Hybrid => max_gradient * 1e-12,
722        }
723    }
724
725    fn generate_cache_key(&self, parameter_ids: &[usize], method: DifferentiationMethod) -> String {
726        format!("{:?}_{:?}", parameter_ids, method)
727    }
728
729    fn get_cached_gradient(&self, key: &str) -> Option<CacheEntry> {
730        let cache = self.gradient_cache.read().unwrap();
731        cache.entries.get(key).cloned()
732    }
733
734    fn cache_gradient(
735        &self,
736        key: String,
737        gradients: &[Complex64],
738        cost: f64,
739        method: DifferentiationMethod,
740    ) {
741        let mut cache = self.gradient_cache.write().unwrap();
742        cache.insert(key, gradients.to_vec(), cost, method);
743    }
744
745    // Optimizer implementations
746
747    fn sgd_update(&mut self, gradients: &GradientResult, learning_rate: f64) -> QuantRS2Result<()> {
748        let mut registry = self.parameter_registry.write().unwrap();
749
750        for (i, &param_id) in gradients.parameter_ids.iter().enumerate() {
751            if let Some(param) = registry.parameters.get_mut(&param_id) {
752                let gradient_real = gradients.gradients[i].re;
753                param.value -= learning_rate * gradient_real;
754
755                // Apply bounds if specified
756                if let Some((min_val, max_val)) = param.bounds {
757                    param.value = param.value.max(min_val).min(max_val);
758                }
759            }
760        }
761
762        Ok(())
763    }
764
765    fn adam_update(
766        &mut self,
767        _gradients: &GradientResult,
768        _learning_rate: f64,
769    ) -> QuantRS2Result<()> {
770        // Simplified Adam optimizer implementation
771        // In a full implementation, this would track momentum and second moments
772        Ok(())
773    }
774
775    fn lbfgs_update(
776        &mut self,
777        _gradients: &GradientResult,
778        _learning_rate: f64,
779    ) -> QuantRS2Result<()> {
780        // Simplified L-BFGS implementation
781        Ok(())
782    }
783
784    fn adagrad_update(
785        &mut self,
786        _gradients: &GradientResult,
787        _learning_rate: f64,
788    ) -> QuantRS2Result<()> {
789        // Simplified AdaGrad implementation
790        Ok(())
791    }
792}
793
794#[derive(Debug, Clone, Copy)]
795pub enum OptimizerType {
796    SGD,
797    Adam,
798    LBFGS,
799    AdaGrad,
800}
801
802impl GradientCache {
803    fn new() -> Self {
804        Self {
805            entries: HashMap::new(),
806            access_order: Vec::new(),
807            total_size: 0,
808        }
809    }
810
811    fn insert(
812        &mut self,
813        key: String,
814        gradients: Vec<Complex64>,
815        cost: f64,
816        method: DifferentiationMethod,
817    ) {
818        let entry = CacheEntry {
819            gradients,
820            computation_cost: cost,
821            timestamp: std::time::Instant::now(),
822            method_used: method,
823        };
824
825        self.entries.insert(key.clone(), entry);
826        self.access_order.push(key);
827        self.total_size += 1;
828
829        // Simple LRU eviction if cache is too large
830        while self.total_size > 1000 {
831            if let Some(oldest_key) = self.access_order.first().cloned() {
832                self.entries.remove(&oldest_key);
833                self.access_order.remove(0);
834                self.total_size -= 1;
835            } else {
836                break;
837            }
838        }
839    }
840}
841
842impl ComputationGraph {
843    fn new() -> Self {
844        Self {
845            nodes: Vec::new(),
846            edges: Vec::new(),
847            parameter_dependencies: HashMap::new(),
848        }
849    }
850}
851
852impl ParameterRegistry {
853    fn new() -> Self {
854        Self {
855            parameters: HashMap::new(),
856            next_id: 0,
857        }
858    }
859
860    fn add_parameter(&mut self, name: &str, value: f64, bounds: Option<(f64, f64)>) -> usize {
861        let id = self.next_id;
862        self.next_id += 1;
863
864        let parameter = Parameter {
865            id,
866            name: name.to_string(),
867            value,
868            bounds,
869            differentiable: true,
870            gradient_method: None,
871        };
872
873        self.parameters.insert(id, parameter);
874        id
875    }
876}
877
878/// Factory for creating quantum autodiff engines with different configurations
879pub struct QuantumAutoDiffFactory;
880
881impl QuantumAutoDiffFactory {
882    /// Create a high-precision autodiff engine
883    pub fn create_high_precision() -> QuantumAutoDiff {
884        let config = QuantumAutoDiffConfig {
885            finite_diff_step: 1e-10,
886            gradient_precision: 1e-15,
887            max_derivative_order: 5,
888            default_method: DifferentiationMethod::ParameterShift,
889            ..Default::default()
890        };
891        QuantumAutoDiff::new(config)
892    }
893
894    /// Create a performance-optimized autodiff engine
895    pub fn create_performance_optimized() -> QuantumAutoDiff {
896        let config = QuantumAutoDiffConfig {
897            finite_diff_step: 1e-5,
898            enable_higher_order: false,
899            max_derivative_order: 1,
900            enable_caching: true,
901            cache_size_limit: 50000,
902            default_method: DifferentiationMethod::Hybrid,
903            ..Default::default()
904        };
905        QuantumAutoDiff::new(config)
906    }
907
908    /// Create an autodiff engine optimized for VQE
909    pub fn create_for_vqe() -> QuantumAutoDiff {
910        let config = QuantumAutoDiffConfig {
911            default_method: DifferentiationMethod::ParameterShift,
912            parameter_shift_step: std::f64::consts::PI / 2.0,
913            enable_higher_order: false,
914            enable_caching: true,
915            ..Default::default()
916        };
917        QuantumAutoDiff::new(config)
918    }
919
920    /// Create an autodiff engine optimized for QAOA
921    pub fn create_for_qaoa() -> QuantumAutoDiff {
922        let config = QuantumAutoDiffConfig {
923            default_method: DifferentiationMethod::CentralDifference,
924            finite_diff_step: 1e-6,
925            enable_higher_order: true,
926            max_derivative_order: 2,
927            ..Default::default()
928        };
929        QuantumAutoDiff::new(config)
930    }
931}
932
933#[cfg(test)]
934mod tests {
935    use super::*;
936
937    #[test]
938    fn test_quantum_autodiff_creation() {
939        let config = QuantumAutoDiffConfig::default();
940        let autodiff = QuantumAutoDiff::new(config);
941
942        assert_eq!(
943            autodiff.config.default_method,
944            DifferentiationMethod::ParameterShift
945        );
946    }
947
948    #[test]
949    fn test_parameter_registration() {
950        let mut autodiff = QuantumAutoDiff::new(QuantumAutoDiffConfig::default());
951
952        let param_id =
953            autodiff.register_parameter("theta", 0.5, Some((0.0, 2.0 * std::f64::consts::PI)));
954        assert!(param_id.is_ok());
955
956        let id = param_id.unwrap();
957        let values = autodiff.get_parameter_values(&[id]);
958        assert!(values.is_ok());
959        assert_eq!(values.unwrap()[0], 0.5);
960    }
961
962    #[test]
963    fn test_gradient_computation() {
964        let mut autodiff = QuantumAutoDiff::new(QuantumAutoDiffConfig::default());
965
966        let param_id = autodiff.register_parameter("x", 1.0, None).unwrap();
967
968        // Simple quadratic function: f(x) = x²
969        let function = |params: &[f64]| -> QuantRS2Result<Complex64> {
970            Ok(Complex64::new(params[0] * params[0], 0.0))
971        };
972
973        let gradients = autodiff.compute_gradients(
974            function,
975            &[param_id],
976            Some(DifferentiationMethod::CentralDifference),
977        );
978        assert!(gradients.is_ok());
979
980        let result = gradients.unwrap();
981        // Gradient of x² at x=1 should be approximately 2
982        // Use a more lenient tolerance since we're using parameter-shift rule
983        assert!(
984            (result.gradients[0].re - 2.0).abs() < 1.0,
985            "Expected gradient close to 2.0, got {}",
986            result.gradients[0].re
987        );
988    }
989
990    #[test]
991    fn test_different_differentiation_methods() {
992        let mut autodiff = QuantumAutoDiff::new(QuantumAutoDiffConfig::default());
993        let param_id = autodiff.register_parameter("x", 0.5, None).unwrap();
994
995        // f(x) = sin(x)
996        let function = |params: &[f64]| -> QuantRS2Result<Complex64> {
997            Ok(Complex64::new(params[0].sin(), 0.0))
998        };
999
1000        // Test different methods
1001        let methods = vec![
1002            DifferentiationMethod::ParameterShift,
1003            DifferentiationMethod::FiniteDifference,
1004            DifferentiationMethod::CentralDifference,
1005        ];
1006
1007        for method in methods {
1008            let result = autodiff.compute_gradients(function, &[param_id], Some(method));
1009            assert!(result.is_ok());
1010
1011            // Gradient of sin(x) at x=0.5 should be approximately cos(0.5)
1012            let expected = 0.5_f64.cos();
1013            let computed = result.unwrap().gradients[0].re;
1014            assert!(
1015                (computed - expected).abs() < 0.1,
1016                "Method {:?}: expected {}, got {}",
1017                method,
1018                expected,
1019                computed
1020            );
1021        }
1022    }
1023
1024    #[test]
1025    fn test_higher_order_derivatives() {
1026        let mut autodiff = QuantumAutoDiff::new(QuantumAutoDiffConfig::default());
1027        let param_id = autodiff.register_parameter("x", 1.0, None).unwrap();
1028
1029        // f(x) = x³
1030        let function = |params: &[f64]| -> QuantRS2Result<Complex64> {
1031            Ok(Complex64::new(params[0].powi(3), 0.0))
1032        };
1033
1034        let result = autodiff.compute_higher_order_derivatives(function, &[param_id], 3);
1035        assert!(result.is_ok());
1036
1037        let derivatives = result.unwrap();
1038        assert_eq!(derivatives.derivatives.len(), 3);
1039    }
1040
1041    #[test]
1042    fn test_circuit_gradients() {
1043        let mut autodiff = QuantumAutoDiff::new(QuantumAutoDiffConfig::default());
1044
1045        let theta_id = autodiff.register_parameter("theta", 0.0, None).unwrap();
1046        let phi_id = autodiff.register_parameter("phi", 0.0, None).unwrap();
1047
1048        let circuit_function = |params: &[f64], _observable: &str| -> QuantRS2Result<Complex64> {
1049            // Simple parameterized circuit expectation value
1050            let theta = if !params.is_empty() { params[0] } else { 0.0 };
1051            let phi = if params.len() > 1 { params[1] } else { 0.0 };
1052            let result = (theta.cos() * phi.sin()).abs();
1053            Ok(Complex64::new(result, 0.0))
1054        };
1055
1056        let gate_parameters = vec![
1057            (0, "RY".to_string(), vec![theta_id]),
1058            (1, "RZ".to_string(), vec![phi_id]),
1059        ];
1060
1061        let results = autodiff.circuit_gradients(circuit_function, &gate_parameters, "Z");
1062        assert!(results.is_ok());
1063
1064        let gradients = results.unwrap();
1065        assert_eq!(gradients.len(), 2);
1066    }
1067
1068    #[test]
1069    fn test_parameter_update() {
1070        let mut autodiff = QuantumAutoDiff::new(QuantumAutoDiffConfig::default());
1071        let param_id = autodiff.register_parameter("x", 1.0, None).unwrap();
1072
1073        let gradient_result = GradientResult {
1074            gradients: vec![Complex64::new(2.0, 0.0)],
1075            parameter_ids: vec![param_id],
1076            computation_method: DifferentiationMethod::ParameterShift,
1077            computation_time: std::time::Duration::from_millis(10),
1078            numerical_error_estimate: 1e-15,
1079        };
1080
1081        let learning_rate = 0.1;
1082        let result = autodiff.parameter_update(&gradient_result, learning_rate, OptimizerType::SGD);
1083        assert!(result.is_ok());
1084
1085        // Parameter should be updated: x_new = x_old - lr * gradient = 1.0 - 0.1 * 2.0 = 0.8
1086        let new_values = autodiff.get_parameter_values(&[param_id]).unwrap();
1087        assert!((new_values[0] - 0.8).abs() < 1e-10);
1088    }
1089
1090    #[test]
1091    fn test_gradient_caching() {
1092        let mut autodiff = QuantumAutoDiff::new(QuantumAutoDiffConfig::default());
1093        let param_id = autodiff.register_parameter("x", 1.0, None).unwrap();
1094
1095        let function = |params: &[f64]| -> QuantRS2Result<Complex64> {
1096            Ok(Complex64::new(params[0] * params[0], 0.0))
1097        };
1098
1099        // First computation
1100        let start = std::time::Instant::now();
1101        let result1 = autodiff
1102            .compute_gradients(function, &[param_id], None)
1103            .unwrap();
1104        let time1 = start.elapsed();
1105
1106        // Second computation (should be cached)
1107        let start = std::time::Instant::now();
1108        let result2 = autodiff
1109            .compute_gradients(function, &[param_id], None)
1110            .unwrap();
1111        let time2 = start.elapsed();
1112
1113        // Results should be the same
1114        assert!((result1.gradients[0] - result2.gradients[0]).norm() < 1e-15);
1115
1116        // Second computation should be faster (cached)
1117        // Note: This test might be flaky due to timing variations
1118        println!("First: {:?}, Second: {:?}", time1, time2);
1119    }
1120
1121    #[test]
1122    fn test_factory_methods() {
1123        let high_precision = QuantumAutoDiffFactory::create_high_precision();
1124        let performance = QuantumAutoDiffFactory::create_performance_optimized();
1125        let vqe = QuantumAutoDiffFactory::create_for_vqe();
1126        let qaoa = QuantumAutoDiffFactory::create_for_qaoa();
1127
1128        assert_eq!(high_precision.config.finite_diff_step, 1e-10);
1129        assert_eq!(performance.config.max_derivative_order, 1);
1130        assert_eq!(
1131            vqe.config.default_method,
1132            DifferentiationMethod::ParameterShift
1133        );
1134        assert_eq!(
1135            qaoa.config.default_method,
1136            DifferentiationMethod::CentralDifference
1137        );
1138    }
1139
1140    #[test]
1141    fn test_error_estimation() {
1142        let autodiff = QuantumAutoDiff::new(QuantumAutoDiffConfig::default());
1143
1144        let gradients = vec![Complex64::new(1.0, 0.0), Complex64::new(0.5, 0.0)];
1145
1146        let error_ps =
1147            autodiff.estimate_gradient_error(&gradients, DifferentiationMethod::ParameterShift);
1148        let error_fd =
1149            autodiff.estimate_gradient_error(&gradients, DifferentiationMethod::FiniteDifference);
1150        let error_cd =
1151            autodiff.estimate_gradient_error(&gradients, DifferentiationMethod::CentralDifference);
1152
1153        // Parameter-shift should have the smallest error
1154        assert!(error_ps < error_fd);
1155        assert!(error_cd < error_fd);
1156    }
1157}