quantrs2_device/quantum_ml/
gradients.rs

1//! Quantum Gradient Computation
2//!
3//! This module implements various methods for computing gradients of quantum circuits,
4//! including parameter shift rules, finite differences, and quantum natural gradients.
5
6use super::*;
7use crate::continuous_variable::Complex;
8use crate::{CircuitExecutor, CircuitResult, DeviceError, DeviceResult, QuantumDevice};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
14/// Quantum gradient calculator
15pub struct QuantumGradientCalculator {
16    device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
17    config: GradientConfig,
18    method: GradientMethod,
19}
20
21/// Configuration for gradient computation
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct GradientConfig {
24    /// Method for computing gradients
25    pub method: GradientMethod,
26    /// Number of shots per evaluation
27    pub shots: usize,
28    /// Finite difference step size
29    pub finite_diff_step: f64,
30    /// Parameter shift rule shift amount
31    pub shift_amount: f64,
32    /// Use error mitigation
33    pub use_error_mitigation: bool,
34    /// Parallel gradient computation
35    pub parallel_execution: bool,
36    /// Gradient clipping threshold
37    pub gradient_clipping: Option<f64>,
38}
39
40impl Default for GradientConfig {
41    fn default() -> Self {
42        Self {
43            method: GradientMethod::ParameterShift,
44            shots: 1024,
45            finite_diff_step: 1e-4,
46            shift_amount: std::f64::consts::PI / 2.0,
47            use_error_mitigation: true,
48            parallel_execution: true,
49            gradient_clipping: Some(1.0),
50        }
51    }
52}
53
54impl QuantumGradientCalculator {
55    /// Create a new gradient calculator
56    pub fn new(
57        device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
58        config: GradientConfig,
59    ) -> DeviceResult<Self> {
60        let method = config.method.clone();
61
62        Ok(Self {
63            device,
64            config,
65            method,
66        })
67    }
68
69    /// Compute gradients for a parameterized quantum circuit
70    pub async fn compute_gradients(
71        &self,
72        circuit: ParameterizedQuantumCircuit,
73        parameters: Vec<f64>,
74    ) -> DeviceResult<Vec<f64>> {
75        match self.method {
76            GradientMethod::ParameterShift => {
77                self.parameter_shift_gradients(circuit, parameters).await
78            }
79            GradientMethod::FiniteDifference => {
80                self.finite_difference_gradients(circuit, parameters).await
81            }
82            GradientMethod::LinearCombination => {
83                self.linear_combination_gradients(circuit, parameters).await
84            }
85            GradientMethod::QuantumNaturalGradient => {
86                self.quantum_natural_gradients(circuit, parameters).await
87            }
88            GradientMethod::Adjoint => self.adjoint_gradients(circuit, parameters).await,
89        }
90    }
91
92    /// Compute gradients using parameter shift rule
93    async fn parameter_shift_gradients(
94        &self,
95        circuit: ParameterizedQuantumCircuit,
96        parameters: Vec<f64>,
97    ) -> DeviceResult<Vec<f64>> {
98        let mut gradients = vec![0.0; parameters.len()];
99        let shift = self.config.shift_amount;
100
101        if self.config.parallel_execution {
102            // Parallel computation of all parameter shifts
103            let mut tasks = Vec::new();
104
105            for i in 0..parameters.len() {
106                let mut params_plus = parameters.clone();
107                let mut params_minus = parameters.clone();
108                params_plus[i] += shift;
109                params_minus[i] -= shift;
110
111                let circuit_plus = circuit.clone();
112                let circuit_minus = circuit.clone();
113                let device_plus = self.device.clone();
114                let device_minus = self.device.clone();
115                let shots = self.config.shots;
116
117                let task_plus = tokio::spawn(async move {
118                    let circuit_eval =
119                        Self::evaluate_circuit_with_params(&circuit_plus, &params_plus)?;
120                    let device = device_plus.read().await;
121                    Self::execute_circuit_helper(&*device, &circuit_eval, shots).await
122                });
123
124                let task_minus = tokio::spawn(async move {
125                    let circuit_eval =
126                        Self::evaluate_circuit_with_params(&circuit_minus, &params_minus)?;
127                    let device = device_minus.read().await;
128                    Self::execute_circuit_helper(&*device, &circuit_eval, shots).await
129                });
130
131                tasks.push((i, task_plus, task_minus));
132            }
133
134            // Collect results
135            for (param_idx, task_plus, task_minus) in tasks {
136                let result_plus = task_plus
137                    .await
138                    .map_err(|e| DeviceError::InvalidInput(format!("Task error: {e}")))??;
139                let result_minus = task_minus
140                    .await
141                    .map_err(|e| DeviceError::InvalidInput(format!("Task error: {e}")))??;
142
143                let expectation_plus = self.compute_expectation_value(&result_plus)?;
144                let expectation_minus = self.compute_expectation_value(&result_minus)?;
145
146                gradients[param_idx] = (expectation_plus - expectation_minus) / 2.0;
147            }
148        } else {
149            // Sequential computation
150            for i in 0..parameters.len() {
151                let mut params_plus = parameters.clone();
152                let mut params_minus = parameters.clone();
153                params_plus[i] += shift;
154                params_minus[i] -= shift;
155
156                let circuit_plus = Self::evaluate_circuit_with_params(&circuit, &params_plus)?;
157                let circuit_minus = Self::evaluate_circuit_with_params(&circuit, &params_minus)?;
158
159                let device = self.device.read().await;
160                let result_plus =
161                    Self::execute_circuit_helper(&*device, &circuit_plus, self.config.shots)
162                        .await?;
163                let result_minus =
164                    Self::execute_circuit_helper(&*device, &circuit_minus, self.config.shots)
165                        .await?;
166
167                let expectation_plus = self.compute_expectation_value(&result_plus)?;
168                let expectation_minus = self.compute_expectation_value(&result_minus)?;
169
170                gradients[i] = (expectation_plus - expectation_minus) / 2.0;
171            }
172        }
173
174        // Apply gradient clipping if specified
175        if let Some(clip_value) = self.config.gradient_clipping {
176            for grad in &mut gradients {
177                *grad = grad.clamp(-clip_value, clip_value);
178            }
179        }
180
181        Ok(gradients)
182    }
183
184    /// Compute gradients using finite differences
185    async fn finite_difference_gradients(
186        &self,
187        circuit: ParameterizedQuantumCircuit,
188        parameters: Vec<f64>,
189    ) -> DeviceResult<Vec<f64>> {
190        let mut gradients = vec![0.0; parameters.len()];
191        let step = self.config.finite_diff_step;
192
193        for i in 0..parameters.len() {
194            let mut params_plus = parameters.clone();
195            let mut params_minus = parameters.clone();
196            params_plus[i] += step;
197            params_minus[i] -= step;
198
199            let circuit_plus = Self::evaluate_circuit_with_params(&circuit, &params_plus)?;
200            let circuit_minus = Self::evaluate_circuit_with_params(&circuit, &params_minus)?;
201
202            let device = self.device.read().await;
203            let result_plus =
204                Self::execute_circuit_helper(&*device, &circuit_plus, self.config.shots).await?;
205            let result_minus =
206                Self::execute_circuit_helper(&*device, &circuit_minus, self.config.shots).await?;
207
208            let expectation_plus = self.compute_expectation_value(&result_plus)?;
209            let expectation_minus = self.compute_expectation_value(&result_minus)?;
210
211            gradients[i] = (expectation_plus - expectation_minus) / (2.0 * step);
212        }
213
214        Ok(gradients)
215    }
216
217    /// Compute gradients using linear combination of unitaries (LCU)
218    async fn linear_combination_gradients(
219        &self,
220        circuit: ParameterizedQuantumCircuit,
221        parameters: Vec<f64>,
222    ) -> DeviceResult<Vec<f64>> {
223        // This is a simplified implementation of LCU gradients
224        // In practice, this would decompose the gradient operator into a linear combination
225        let mut gradients = vec![0.0; parameters.len()];
226
227        for i in 0..parameters.len() {
228            // Simplified: use a small finite difference as approximation
229            let step = 1e-3;
230            let mut params_plus = parameters.clone();
231            params_plus[i] += step;
232
233            let circuit_original = Self::evaluate_circuit_with_params(&circuit, &parameters)?;
234            let circuit_plus = Self::evaluate_circuit_with_params(&circuit, &params_plus)?;
235
236            let device = self.device.read().await;
237            let result_original =
238                Self::execute_circuit_helper(&*device, &circuit_original, self.config.shots)
239                    .await?;
240            let result_plus =
241                Self::execute_circuit_helper(&*device, &circuit_plus, self.config.shots).await?;
242
243            let expectation_original = self.compute_expectation_value(&result_original)?;
244            let expectation_plus = self.compute_expectation_value(&result_plus)?;
245
246            gradients[i] = (expectation_plus - expectation_original) / step;
247        }
248
249        Ok(gradients)
250    }
251
252    /// Compute quantum natural gradients
253    async fn quantum_natural_gradients(
254        &self,
255        circuit: ParameterizedQuantumCircuit,
256        parameters: Vec<f64>,
257    ) -> DeviceResult<Vec<f64>> {
258        // First compute regular gradients
259        let regular_gradients = self
260            .parameter_shift_gradients(circuit.clone(), parameters.clone())
261            .await?;
262
263        // Compute quantum Fisher information matrix (simplified)
264        let fisher_matrix = self
265            .compute_quantum_fisher_information(&circuit, &parameters)
266            .await?;
267
268        // Solve Fisher^{-1} * gradient
269        let natural_gradients = self.solve_linear_system(&fisher_matrix, &regular_gradients)?;
270
271        Ok(natural_gradients)
272    }
273
274    /// Compute gradients using adjoint method (simplified)
275    async fn adjoint_gradients(
276        &self,
277        circuit: ParameterizedQuantumCircuit,
278        parameters: Vec<f64>,
279    ) -> DeviceResult<Vec<f64>> {
280        // This is a placeholder for adjoint gradient computation
281        // Real implementation would require access to quantum state amplitudes
282        // For now, fall back to parameter shift rule
283        self.parameter_shift_gradients(circuit, parameters).await
284    }
285
286    /// Compute quantum Fisher information matrix
287    async fn compute_quantum_fisher_information(
288        &self,
289        circuit: &ParameterizedQuantumCircuit,
290        parameters: &[f64],
291    ) -> DeviceResult<Vec<Vec<f64>>> {
292        let n_params = parameters.len();
293        let mut fisher_matrix = vec![vec![0.0; n_params]; n_params];
294        let shift = std::f64::consts::PI / 2.0;
295
296        for i in 0..n_params {
297            for j in i..n_params {
298                if i == j {
299                    // Diagonal elements: Var[∂ψ/∂θᵢ]
300                    let mut params_plus = parameters.to_vec();
301                    let mut params_minus = parameters.to_vec();
302                    params_plus[i] += shift;
303                    params_minus[i] -= shift;
304
305                    let circuit_plus = Self::evaluate_circuit_with_params(circuit, &params_plus)?;
306                    let circuit_minus = Self::evaluate_circuit_with_params(circuit, &params_minus)?;
307
308                    let device = self.device.read().await;
309                    let result_plus =
310                        Self::execute_circuit_helper(&*device, &circuit_plus, self.config.shots)
311                            .await?;
312                    let result_minus =
313                        Self::execute_circuit_helper(&*device, &circuit_minus, self.config.shots)
314                            .await?;
315
316                    let overlap = self.compute_state_overlap(&result_plus, &result_minus)?;
317                    fisher_matrix[i][j] = (1.0 - overlap.real) / 2.0;
318                } else {
319                    // Off-diagonal elements: Re[⟨∂ψ/∂θᵢ|∂ψ/∂θⱼ⟩]
320                    // Simplified computation
321                    fisher_matrix[i][j] = 0.0;
322                    fisher_matrix[j][i] = fisher_matrix[i][j];
323                }
324            }
325        }
326
327        // Add regularization to ensure invertibility
328        for i in 0..n_params {
329            fisher_matrix[i][i] += 1e-6;
330        }
331
332        Ok(fisher_matrix)
333    }
334
335    /// Compute overlap between quantum states (simplified)
336    fn compute_state_overlap(
337        &self,
338        result1: &CircuitResult,
339        result2: &CircuitResult,
340    ) -> DeviceResult<Complex> {
341        // This is a simplified overlap computation based on measurement statistics
342        // Real implementation would require access to quantum state amplitudes
343
344        let mut overlap_real = 0.0;
345        let total_shots1 = result1.shots as f64;
346        let total_shots2 = result2.shots as f64;
347
348        for (bitstring, count1) in &result1.counts {
349            if let Some(count2) = result2.counts.get(bitstring) {
350                let prob1 = *count1 as f64 / total_shots1;
351                let prob2 = *count2 as f64 / total_shots2;
352                overlap_real += (prob1 * prob2).sqrt();
353            }
354        }
355
356        Ok(Complex::new(overlap_real, 0.0))
357    }
358
359    /// Solve linear system Ax = b
360    fn solve_linear_system(&self, matrix: &[Vec<f64>], vector: &[f64]) -> DeviceResult<Vec<f64>> {
361        let n = matrix.len();
362        if n != vector.len() {
363            return Err(DeviceError::InvalidInput(
364                "Matrix and vector dimensions don't match".to_string(),
365            ));
366        }
367
368        // Simple Gaussian elimination (for small systems)
369        let mut augmented = matrix
370            .iter()
371            .zip(vector.iter())
372            .map(|(row, &b)| {
373                let mut aug_row = row.clone();
374                aug_row.push(b);
375                aug_row
376            })
377            .collect::<Vec<_>>();
378
379        // Forward elimination
380        for i in 0..n {
381            // Find pivot
382            let mut max_row = i;
383            for k in i + 1..n {
384                if augmented[k][i].abs() > augmented[max_row][i].abs() {
385                    max_row = k;
386                }
387            }
388            augmented.swap(i, max_row);
389
390            // Check for singularity
391            if augmented[i][i].abs() < 1e-10 {
392                return Err(DeviceError::InvalidInput(
393                    "Singular matrix in linear system".to_string(),
394                ));
395            }
396
397            // Eliminate
398            for k in i + 1..n {
399                let factor = augmented[k][i] / augmented[i][i];
400                for j in i..=n {
401                    augmented[k][j] -= factor * augmented[i][j];
402                }
403            }
404        }
405
406        // Back substitution
407        let mut solution = vec![0.0; n];
408        for i in (0..n).rev() {
409            solution[i] = augmented[i][n];
410            for j in i + 1..n {
411                solution[i] -= augmented[i][j] * solution[j];
412            }
413            solution[i] /= augmented[i][i];
414        }
415
416        Ok(solution)
417    }
418
419    /// Execute a circuit on the quantum device
420    async fn execute_circuit_helper(
421        device: &(dyn QuantumDevice + Send + Sync),
422        circuit: &ParameterizedQuantumCircuit,
423        shots: usize,
424    ) -> DeviceResult<CircuitResult> {
425        // For now, return a mock result since we can't execute circuits directly
426        // In a real implementation, this would need proper circuit execution
427        let mut counts = std::collections::HashMap::new();
428        counts.insert("0".repeat(circuit.num_qubits()), shots / 2);
429        counts.insert("1".repeat(circuit.num_qubits()), shots / 2);
430
431        Ok(CircuitResult {
432            counts,
433            shots,
434            metadata: std::collections::HashMap::new(),
435        })
436    }
437
438    /// Evaluate a parameterized circuit with specific parameter values
439    fn evaluate_circuit_with_params(
440        circuit: &ParameterizedQuantumCircuit,
441        parameters: &[f64],
442    ) -> DeviceResult<ParameterizedQuantumCircuit> {
443        // This would substitute parameters into the circuit
444        // For now, return a copy (implementation would be more sophisticated)
445        Ok(circuit.clone())
446    }
447
448    /// Compute expectation value from measurement results
449    fn compute_expectation_value(&self, result: &CircuitResult) -> DeviceResult<f64> {
450        // Simple expectation value: average number of 1s
451        let mut expectation = 0.0;
452        let total_shots = result.shots as f64;
453
454        for (bitstring, count) in &result.counts {
455            let ones_count = bitstring.chars().filter(|&c| c == '1').count();
456            let probability = *count as f64 / total_shots;
457            expectation += ones_count as f64 * probability;
458        }
459
460        Ok(expectation)
461    }
462
463    /// Compute gradients with respect to a specific observable
464    pub async fn compute_observable_gradients(
465        &self,
466        circuit: ParameterizedQuantumCircuit,
467        parameters: Vec<f64>,
468        observable: Observable,
469    ) -> DeviceResult<Vec<f64>> {
470        match self.method {
471            GradientMethod::ParameterShift => {
472                self.parameter_shift_observable_gradients(circuit, parameters, observable)
473                    .await
474            }
475            _ => {
476                // For other methods, use default expectation value
477                self.compute_gradients(circuit, parameters).await
478            }
479        }
480    }
481
482    /// Compute gradients with respect to a specific observable using parameter shift
483    async fn parameter_shift_observable_gradients(
484        &self,
485        circuit: ParameterizedQuantumCircuit,
486        parameters: Vec<f64>,
487        observable: Observable,
488    ) -> DeviceResult<Vec<f64>> {
489        let mut gradients = vec![0.0; parameters.len()];
490        let shift = self.config.shift_amount;
491
492        for i in 0..parameters.len() {
493            let mut params_plus = parameters.clone();
494            let mut params_minus = parameters.clone();
495            params_plus[i] += shift;
496            params_minus[i] -= shift;
497
498            let circuit_plus = Self::evaluate_circuit_with_params(&circuit, &params_plus)?;
499            let circuit_minus = Self::evaluate_circuit_with_params(&circuit, &params_minus)?;
500
501            let device = self.device.read().await;
502            let result_plus =
503                Self::execute_circuit_helper(&*device, &circuit_plus, self.config.shots).await?;
504            let result_minus =
505                Self::execute_circuit_helper(&*device, &circuit_minus, self.config.shots).await?;
506
507            let expectation_plus =
508                self.compute_observable_expectation(&result_plus, &observable)?;
509            let expectation_minus =
510                self.compute_observable_expectation(&result_minus, &observable)?;
511
512            gradients[i] = (expectation_plus - expectation_minus) / 2.0;
513        }
514
515        Ok(gradients)
516    }
517
518    /// Compute expectation value of an observable
519    fn compute_observable_expectation(
520        &self,
521        result: &CircuitResult,
522        observable: &Observable,
523    ) -> DeviceResult<f64> {
524        let mut expectation = 0.0;
525        let total_shots = result.shots as f64;
526
527        for (bitstring, count) in &result.counts {
528            let probability = *count as f64 / total_shots;
529            let eigenvalue = observable.evaluate_bitstring(bitstring)?;
530            expectation += probability * eigenvalue;
531        }
532
533        Ok(expectation)
534    }
535}
536
537/// Observable for expectation value computation
538#[derive(Debug, Clone, Serialize, Deserialize)]
539pub struct Observable {
540    pub terms: Vec<ObservableTerm>,
541}
542
543/// Single term in an observable
544#[derive(Debug, Clone, Serialize, Deserialize)]
545pub struct ObservableTerm {
546    pub coefficient: f64,
547    pub pauli_string: Vec<(usize, PauliOperator)>, // (qubit_index, pauli_operator)
548}
549
550impl Observable {
551    /// Create a Z observable on a single qubit
552    pub fn single_z(qubit: usize) -> Self {
553        Self {
554            terms: vec![ObservableTerm {
555                coefficient: 1.0,
556                pauli_string: vec![(qubit, PauliOperator::Z)],
557            }],
558        }
559    }
560
561    /// Create an all-Z observable (sum of Z on all qubits)
562    pub fn all_z(num_qubits: usize) -> Self {
563        let terms = (0..num_qubits)
564            .map(|i| ObservableTerm {
565                coefficient: 1.0,
566                pauli_string: vec![(i, PauliOperator::Z)],
567            })
568            .collect();
569
570        Self { terms }
571    }
572
573    /// Evaluate observable for a given bitstring
574    pub fn evaluate_bitstring(&self, bitstring: &str) -> DeviceResult<f64> {
575        let mut value = 0.0;
576
577        for term in &self.terms {
578            let mut term_value = term.coefficient;
579
580            for (qubit_idx, pauli_op) in &term.pauli_string {
581                if let Some(bit_char) = bitstring.chars().nth(*qubit_idx) {
582                    let bit_value = if bit_char == '1' { -1.0 } else { 1.0 };
583
584                    match pauli_op {
585                        PauliOperator::Z => term_value *= bit_value,
586                        PauliOperator::I => {} // Identity
587                        PauliOperator::X | PauliOperator::Y => {
588                            // Would need basis rotation for X/Y measurements
589                            return Err(DeviceError::InvalidInput(
590                                "X and Y Pauli measurements require basis rotation".to_string(),
591                            ));
592                        }
593                    }
594                }
595            }
596
597            value += term_value;
598        }
599
600        Ok(value)
601    }
602}
603
/// Stateless helpers for numerical gradient estimation and post-processing.
pub struct GradientUtils;

impl GradientUtils {
    /// Estimate the gradient of `f` at `parameters` using central differences.
    ///
    /// Component `i` equals `(f(p + h·eᵢ) − f(p − h·eᵢ)) / (2h)` with
    /// `h = step_size`. For each component the forward point is evaluated
    /// before the backward point.
    pub fn central_difference(
        f: impl Fn(&[f64]) -> f64,
        parameters: &[f64],
        step_size: f64,
    ) -> Vec<f64> {
        let span = 2.0 * step_size;

        (0..parameters.len())
            .map(|idx| {
                let mut forward = parameters.to_vec();
                forward[idx] += step_size;
                let mut backward = parameters.to_vec();
                backward[idx] -= step_size;

                (f(&forward) - f(&backward)) / span
            })
            .collect()
    }

    /// Rescale `gradients` in place so their Euclidean norm is at most `max_norm`.
    ///
    /// Gradients whose norm already satisfies the bound are left untouched.
    pub fn clip_gradients(gradients: &mut [f64], max_norm: f64) {
        let squared_norm = gradients.iter().fold(0.0_f64, |acc, &g| acc + g * g);
        let norm = squared_norm.sqrt();

        if norm > max_norm {
            let factor = max_norm / norm;
            gradients.iter_mut().for_each(|g| *g *= factor);
        }
    }

    /// Update the momentum buffer (`v ← momentum·v + g`, fused multiply-add)
    /// and return the updated velocities.
    ///
    /// The buffer is first resized to `gradients.len()`: new slots start at
    /// zero, and extra slots are dropped.
    pub fn apply_momentum(
        gradients: &[f64],
        momentum_buffer: &mut Vec<f64>,
        momentum: f64,
    ) -> Vec<f64> {
        momentum_buffer.resize(gradients.len(), 0.0);

        momentum_buffer
            .iter_mut()
            .zip(gradients)
            .map(|(velocity, &grad)| {
                *velocity = momentum.mul_add(*velocity, grad);
                *velocity
            })
            .collect()
    }
}
661
662/// Create a parameter shift gradient calculator
663pub fn create_parameter_shift_calculator(
664    device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
665    shots: usize,
666) -> DeviceResult<QuantumGradientCalculator> {
667    let config = GradientConfig {
668        method: GradientMethod::ParameterShift,
669        shots,
670        ..Default::default()
671    };
672
673    QuantumGradientCalculator::new(device, config)
674}
675
676/// Create a finite difference gradient calculator
677pub fn create_finite_difference_calculator(
678    device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
679    step_size: f64,
680) -> DeviceResult<QuantumGradientCalculator> {
681    let config = GradientConfig {
682        method: GradientMethod::FiniteDifference,
683        finite_diff_step: step_size,
684        ..Default::default()
685    };
686
687    QuantumGradientCalculator::new(device, config)
688}
689
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::create_mock_quantum_device;

    /// Tolerance for numerically estimated derivatives.
    const FD_TOLERANCE: f64 = 1e-3;
    /// Tolerance for quantities computed in closed form.
    const EXACT_TOLERANCE: f64 = 1e-10;

    #[tokio::test]
    async fn test_gradient_calculator_creation() {
        let mock_device = create_mock_quantum_device();
        let calculator = QuantumGradientCalculator::new(mock_device, GradientConfig::default())
            .expect("QuantumGradientCalculator creation should succeed with default config");

        let config = &calculator.config;
        assert_eq!(config.method, GradientMethod::ParameterShift);
        assert_eq!(config.shots, 1024);
    }

    #[test]
    fn test_observable_creation() {
        let single = Observable::single_z(0);
        assert_eq!(single.terms.len(), 1);
        assert_eq!(single.terms[0].coefficient, 1.0);

        assert_eq!(Observable::all_z(4).terms.len(), 4);
    }

    #[test]
    fn test_observable_evaluation() {
        let z0 = Observable::single_z(0);

        let ground = z0
            .evaluate_bitstring("0")
            .expect("Observable evaluation should succeed for bitstring '0'");
        assert_eq!(ground, 1.0);

        let excited = z0
            .evaluate_bitstring("1")
            .expect("Observable evaluation should succeed for bitstring '1'");
        assert_eq!(excited, -1.0);
    }

    #[test]
    fn test_gradient_utils() {
        // f(x, y) = x² + 2y²  ⇒  ∇f = (2x, 4y) = (2, 8) at (1, 2).
        let quadratic = |p: &[f64]| p[0] * p[0] + 2.0 * p[1] * p[1];
        let grad = GradientUtils::central_difference(quadratic, &[1.0, 2.0], 1e-5);

        assert!((grad[0] - 2.0).abs() < FD_TOLERANCE);
        assert!((grad[1] - 8.0).abs() < FD_TOLERANCE);
    }

    #[test]
    fn test_gradient_clipping() {
        // ‖(3, 4)‖ = 5, which exceeds the bound of 2.
        let mut grad = vec![3.0, 4.0];
        GradientUtils::clip_gradients(&mut grad, 2.0);

        let clipped_norm = grad.iter().map(|g| g * g).sum::<f64>().sqrt();
        assert!((clipped_norm - 2.0).abs() < EXACT_TOLERANCE);
    }

    #[test]
    fn test_momentum() {
        let grad = vec![1.0, 2.0];
        let mut velocity = vec![0.5, -0.5];

        let updated = GradientUtils::apply_momentum(&grad, &mut velocity, 0.9);

        // 0.9·0.5 + 1.0 = 1.45 and 0.9·(−0.5) + 2.0 = 1.55.
        assert!((updated[0] - 1.45).abs() < EXACT_TOLERANCE);
        assert!((updated[1] - 1.55).abs() < EXACT_TOLERANCE);
    }
}