quantrs2_core/
variational.rs

1//! Variational quantum gates with automatic differentiation support
2//!
3//! This module provides variational quantum gates whose parameters can be optimized
4//! using gradient-based methods. It includes automatic differentiation for computing
5//! parameter gradients efficiently.
6
7use crate::{
8    error::{QuantRS2Error, QuantRS2Result},
9    gate::GateOp,
10    matrix_ops::{DenseMatrix, QuantumMatrix},
11    qubit::QubitId,
12    register::Register,
13};
14use ndarray::{Array1, Array2};
15use num_complex::Complex;
16use rustc_hash::FxHashMap;
17use std::any::Any;
18use std::f64::consts::PI;
19use std::sync::Arc;
20
/// Selects how a variational gate computes the gradient of a loss with respect to its
/// parameters.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DiffMode {
    /// Forward-mode automatic differentiation (dual numbers).
    Forward,
    /// Reverse-mode automatic differentiation (backpropagation over a computation graph).
    Reverse,
    /// Quantum parameter-shift rule: the loss is evaluated at `θ + π/2` and `θ − π/2`.
    ParameterShift,
    /// Finite-difference approximation of the derivative.
    FiniteDiff {
        /// Step added to a parameter when forming the difference quotient.
        epsilon: f64,
    },
}
33
/// Dual number `real + dual·ε` (with `ε² = 0`) used for forward-mode differentiation.
///
/// Carrying a value through arithmetic on `Dual`s propagates its derivative in the
/// `dual` component.
#[derive(Debug, Clone, Copy)]
pub struct Dual {
    /// Primal value.
    pub real: f64,
    /// Tangent (derivative) component.
    pub dual: f64,
}

impl Dual {
    /// Builds a dual number from its value part and its derivative part.
    pub fn new(real: f64, dual: f64) -> Self {
        Self { real, dual }
    }

    /// Wraps a constant: its derivative is zero.
    pub fn constant(value: f64) -> Self {
        Self::new(value, 0.0)
    }

    /// Wraps the independent variable: its derivative with respect to itself is one.
    pub fn variable(value: f64) -> Self {
        Self::new(value, 1.0)
    }
}

// Arithmetic follows the ordinary sum, difference, product and quotient rules.
impl std::ops::Add for Dual {
    type Output = Self;

    /// Sum rule: `(u + v)' = u' + v'`.
    fn add(self, rhs: Self) -> Self {
        Dual::new(self.real + rhs.real, self.dual + rhs.dual)
    }
}

impl std::ops::Sub for Dual {
    type Output = Self;

    /// Difference rule: `(u − v)' = u' − v'`.
    fn sub(self, rhs: Self) -> Self {
        Dual::new(self.real - rhs.real, self.dual - rhs.dual)
    }
}

impl std::ops::Mul for Dual {
    type Output = Self;

    /// Product rule: `(u·v)' = u·v' + u'·v`.
    fn mul(self, rhs: Self) -> Self {
        let value = self.real * rhs.real;
        let tangent = self.real * rhs.dual + self.dual * rhs.real;
        Dual::new(value, tangent)
    }
}

impl std::ops::Div for Dual {
    type Output = Self;

    /// Quotient rule: `(u/v)' = (u'·v − u·v') / v²`.
    fn div(self, rhs: Self) -> Self {
        let numerator = self.dual * rhs.real - self.real * rhs.dual;
        let denominator = rhs.real * rhs.real;
        Dual::new(self.real / rhs.real, numerator / denominator)
    }
}

// Elementary functions; each applies the chain rule `f(u)' = f'(u)·u'`.
impl Dual {
    /// Sine, with derivative `cos(u)·u'`.
    pub fn sin(self) -> Self {
        let slope = self.real.cos();
        Dual::new(self.real.sin(), self.dual * slope)
    }

    /// Cosine, with derivative `−sin(u)·u'`.
    pub fn cos(self) -> Self {
        let tangent = -self.dual * self.real.sin();
        Dual::new(self.real.cos(), tangent)
    }

    /// Natural exponential, with derivative `exp(u)·u'`.
    pub fn exp(self) -> Self {
        let value = self.real.exp();
        Dual::new(value, self.dual * value)
    }

    /// Square root, with derivative `u' / (2·√u)`.
    pub fn sqrt(self) -> Self {
        let root = self.real.sqrt();
        Dual::new(root, self.dual / (2.0 * root))
    }
}
143
144/// Computation graph node for reverse-mode autodiff
145#[derive(Debug, Clone)]
146pub struct Node {
147    /// Node identifier
148    pub id: usize,
149    /// Value at this node
150    pub value: Complex<f64>,
151    /// Gradient accumulated at this node
152    pub grad: Complex<f64>,
153    /// Operation that produced this node
154    pub op: Operation,
155    /// Parent nodes
156    pub parents: Vec<usize>,
157}
158
/// Kind of operation recorded on a reverse-mode differentiation tape.
#[derive(Clone, Debug)]
pub enum Operation {
    /// Named real input whose gradient can be queried.
    Parameter(String),
    /// Fixed value that does not depend on any input.
    Constant,
    /// Sum of two nodes.
    Add,
    /// Product of two nodes.
    Mul,
    /// Complex conjugation of a node.
    Conj,
    /// Matrix product.
    MatMul,
    /// `e^{iθ}` of a real-valued node `θ`.
    ExpI,
}
177
178/// Computation graph for reverse-mode autodiff
179#[derive(Debug)]
180pub struct ComputationGraph {
181    /// Nodes in the graph
182    nodes: Vec<Node>,
183    /// Parameter name to node ID mapping
184    params: FxHashMap<String, usize>,
185    /// Next available node ID
186    next_id: usize,
187}
188
189impl ComputationGraph {
190    /// Create a new computation graph
191    pub fn new() -> Self {
192        Self {
193            nodes: Vec::new(),
194            params: FxHashMap::default(),
195            next_id: 0,
196        }
197    }
198
199    /// Add a parameter node
200    pub fn parameter(&mut self, name: String, value: f64) -> usize {
201        let id = self.next_id;
202        self.next_id += 1;
203
204        let node = Node {
205            id,
206            value: Complex::new(value, 0.0),
207            grad: Complex::new(0.0, 0.0),
208            op: Operation::Parameter(name.clone()),
209            parents: vec![],
210        };
211
212        self.nodes.push(node);
213        self.params.insert(name, id);
214        id
215    }
216
217    /// Add a constant node
218    pub fn constant(&mut self, value: Complex<f64>) -> usize {
219        let id = self.next_id;
220        self.next_id += 1;
221
222        let node = Node {
223            id,
224            value,
225            grad: Complex::new(0.0, 0.0),
226            op: Operation::Constant,
227            parents: vec![],
228        };
229
230        self.nodes.push(node);
231        id
232    }
233
234    /// Add two nodes
235    pub fn add(&mut self, a: usize, b: usize) -> usize {
236        let id = self.next_id;
237        self.next_id += 1;
238
239        let value = self.nodes[a].value + self.nodes[b].value;
240
241        let node = Node {
242            id,
243            value,
244            grad: Complex::new(0.0, 0.0),
245            op: Operation::Add,
246            parents: vec![a, b],
247        };
248
249        self.nodes.push(node);
250        id
251    }
252
253    /// Multiply two nodes
254    pub fn mul(&mut self, a: usize, b: usize) -> usize {
255        let id = self.next_id;
256        self.next_id += 1;
257
258        let value = self.nodes[a].value * self.nodes[b].value;
259
260        let node = Node {
261            id,
262            value,
263            grad: Complex::new(0.0, 0.0),
264            op: Operation::Mul,
265            parents: vec![a, b],
266        };
267
268        self.nodes.push(node);
269        id
270    }
271
272    /// Exponential of i times a real parameter
273    pub fn exp_i(&mut self, theta: usize) -> usize {
274        let id = self.next_id;
275        self.next_id += 1;
276
277        let theta_val = self.nodes[theta].value.re;
278        let value = Complex::new(theta_val.cos(), theta_val.sin());
279
280        let node = Node {
281            id,
282            value,
283            grad: Complex::new(0.0, 0.0),
284            op: Operation::ExpI,
285            parents: vec![theta],
286        };
287
288        self.nodes.push(node);
289        id
290    }
291
292    /// Backward pass to compute gradients
293    pub fn backward(&mut self, output: usize) {
294        // Initialize output gradient
295        self.nodes[output].grad = Complex::new(1.0, 0.0);
296
297        // Traverse in reverse topological order
298        for i in (0..=output).rev() {
299            let grad = self.nodes[i].grad;
300            let parents = self.nodes[i].parents.clone();
301            let op = self.nodes[i].op.clone();
302
303            match op {
304                Operation::Add => {
305                    // d/da (a + b) = 1, d/db (a + b) = 1
306                    if !parents.is_empty() {
307                        self.nodes[parents[0]].grad += grad;
308                        self.nodes[parents[1]].grad += grad;
309                    }
310                }
311                Operation::Mul => {
312                    // d/da (a * b) = b, d/db (a * b) = a
313                    if !parents.is_empty() {
314                        let a = parents[0];
315                        let b = parents[1];
316                        let b_value = self.nodes[b].value;
317                        let a_value = self.nodes[a].value;
318                        self.nodes[a].grad += grad * b_value;
319                        self.nodes[b].grad += grad * a_value;
320                    }
321                }
322                Operation::ExpI => {
323                    // d/dθ e^(iθ) = i * e^(iθ)
324                    if !parents.is_empty() {
325                        let theta = parents[0];
326                        let node_value = self.nodes[i].value;
327                        self.nodes[theta].grad += grad * Complex::new(0.0, 1.0) * node_value;
328                    }
329                }
330                _ => {}
331            }
332        }
333    }
334
335    /// Get gradient for a parameter
336    pub fn get_gradient(&self, param: &str) -> Option<f64> {
337        self.params.get(param).map(|&id| self.nodes[id].grad.re)
338    }
339}
340
341/// Variational quantum gate with autodiff support
342#[derive(Clone)]
343pub struct VariationalGate {
344    /// Gate name
345    pub name: String,
346    /// Target qubits
347    pub qubits: Vec<QubitId>,
348    /// Parameter names
349    pub params: Vec<String>,
350    /// Current parameter values
351    pub values: Vec<f64>,
352    /// Gate generator function
353    pub generator: Arc<dyn Fn(&[f64]) -> Array2<Complex<f64>> + Send + Sync>,
354    /// Differentiation mode
355    pub diff_mode: DiffMode,
356}
357
358impl VariationalGate {
359    /// Create a variational rotation gate around X axis
360    pub fn rx(qubit: QubitId, param_name: String, initial_value: f64) -> Self {
361        let generator = Arc::new(|params: &[f64]| {
362            let theta = params[0];
363            let cos_half = (theta / 2.0).cos();
364            let sin_half = (theta / 2.0).sin();
365
366            Array2::from_shape_vec(
367                (2, 2),
368                vec![
369                    Complex::new(cos_half, 0.0),
370                    Complex::new(0.0, -sin_half),
371                    Complex::new(0.0, -sin_half),
372                    Complex::new(cos_half, 0.0),
373                ],
374            )
375            .unwrap()
376        });
377
378        Self {
379            name: format!("RX({})", param_name),
380            qubits: vec![qubit],
381            params: vec![param_name],
382            values: vec![initial_value],
383            generator,
384            diff_mode: DiffMode::ParameterShift,
385        }
386    }
387
388    /// Create a variational rotation gate around Y axis
389    pub fn ry(qubit: QubitId, param_name: String, initial_value: f64) -> Self {
390        let generator = Arc::new(|params: &[f64]| {
391            let theta = params[0];
392            let cos_half = (theta / 2.0).cos();
393            let sin_half = (theta / 2.0).sin();
394
395            Array2::from_shape_vec(
396                (2, 2),
397                vec![
398                    Complex::new(cos_half, 0.0),
399                    Complex::new(-sin_half, 0.0),
400                    Complex::new(sin_half, 0.0),
401                    Complex::new(cos_half, 0.0),
402                ],
403            )
404            .unwrap()
405        });
406
407        Self {
408            name: format!("RY({})", param_name),
409            qubits: vec![qubit],
410            params: vec![param_name],
411            values: vec![initial_value],
412            generator,
413            diff_mode: DiffMode::ParameterShift,
414        }
415    }
416
417    /// Create a variational rotation gate around Z axis
418    pub fn rz(qubit: QubitId, param_name: String, initial_value: f64) -> Self {
419        let generator = Arc::new(|params: &[f64]| {
420            let theta = params[0];
421            let exp_pos = Complex::new((theta / 2.0).cos(), (theta / 2.0).sin());
422            let exp_neg = Complex::new((theta / 2.0).cos(), -(theta / 2.0).sin());
423
424            Array2::from_shape_vec(
425                (2, 2),
426                vec![
427                    exp_neg,
428                    Complex::new(0.0, 0.0),
429                    Complex::new(0.0, 0.0),
430                    exp_pos,
431                ],
432            )
433            .unwrap()
434        });
435
436        Self {
437            name: format!("RZ({})", param_name),
438            qubits: vec![qubit],
439            params: vec![param_name],
440            values: vec![initial_value],
441            generator,
442            diff_mode: DiffMode::ParameterShift,
443        }
444    }
445
446    /// Create a variational controlled rotation gate
447    pub fn cry(control: QubitId, target: QubitId, param_name: String, initial_value: f64) -> Self {
448        let generator = Arc::new(|params: &[f64]| {
449            let theta = params[0];
450            let cos_half = (theta / 2.0).cos();
451            let sin_half = (theta / 2.0).sin();
452
453            let mut matrix = Array2::eye(4).mapv(|x| Complex::new(x, 0.0));
454            // Apply RY to target when control is |1⟩
455            matrix[[2, 2]] = Complex::new(cos_half, 0.0);
456            matrix[[2, 3]] = Complex::new(-sin_half, 0.0);
457            matrix[[3, 2]] = Complex::new(sin_half, 0.0);
458            matrix[[3, 3]] = Complex::new(cos_half, 0.0);
459
460            matrix
461        });
462
463        Self {
464            name: format!("CRY({}, {})", param_name, control.0),
465            qubits: vec![control, target],
466            params: vec![param_name],
467            values: vec![initial_value],
468            generator,
469            diff_mode: DiffMode::ParameterShift,
470        }
471    }
472
473    /// Get current parameter values
474    pub fn get_params(&self) -> &[f64] {
475        &self.values
476    }
477
478    /// Set parameter values
479    pub fn set_params(&mut self, values: Vec<f64>) -> QuantRS2Result<()> {
480        if values.len() != self.params.len() {
481            return Err(QuantRS2Error::InvalidInput(format!(
482                "Expected {} parameters, got {}",
483                self.params.len(),
484                values.len()
485            )));
486        }
487        self.values = values;
488        Ok(())
489    }
490
491    /// Compute gradient with respect to parameters
492    pub fn gradient(
493        &self,
494        loss_fn: impl Fn(&Array2<Complex<f64>>) -> f64,
495    ) -> QuantRS2Result<Vec<f64>> {
496        match self.diff_mode {
497            DiffMode::ParameterShift => self.parameter_shift_gradient(loss_fn),
498            DiffMode::FiniteDiff { epsilon } => self.finite_diff_gradient(loss_fn, epsilon),
499            DiffMode::Forward => self.forward_mode_gradient(loss_fn),
500            DiffMode::Reverse => self.reverse_mode_gradient(loss_fn),
501        }
502    }
503
504    /// Parameter shift rule for gradient computation
505    fn parameter_shift_gradient(
506        &self,
507        loss_fn: impl Fn(&Array2<Complex<f64>>) -> f64,
508    ) -> QuantRS2Result<Vec<f64>> {
509        let mut gradients = vec![0.0; self.params.len()];
510
511        for (i, &value) in self.values.iter().enumerate() {
512            // Shift parameter by +π/2
513            let mut params_plus = self.values.clone();
514            params_plus[i] = value + PI / 2.0;
515            let matrix_plus = (self.generator)(&params_plus);
516            let loss_plus = loss_fn(&matrix_plus);
517
518            // Shift parameter by -π/2
519            let mut params_minus = self.values.clone();
520            params_minus[i] = value - PI / 2.0;
521            let matrix_minus = (self.generator)(&params_minus);
522            let loss_minus = loss_fn(&matrix_minus);
523
524            // Gradient via parameter shift rule
525            gradients[i] = (loss_plus - loss_minus) / 2.0;
526        }
527
528        Ok(gradients)
529    }
530
531    /// Finite differences gradient approximation
532    fn finite_diff_gradient(
533        &self,
534        loss_fn: impl Fn(&Array2<Complex<f64>>) -> f64,
535        epsilon: f64,
536    ) -> QuantRS2Result<Vec<f64>> {
537        let mut gradients = vec![0.0; self.params.len()];
538
539        for (i, &value) in self.values.iter().enumerate() {
540            // Forward difference
541            let mut params_plus = self.values.clone();
542            params_plus[i] = value + epsilon;
543            let matrix_plus = (self.generator)(&params_plus);
544            let loss_plus = loss_fn(&matrix_plus);
545
546            // Current value
547            let matrix = (self.generator)(&self.values);
548            let loss = loss_fn(&matrix);
549
550            // Gradient approximation
551            gradients[i] = (loss_plus - loss) / epsilon;
552        }
553
554        Ok(gradients)
555    }
556
557    /// Forward-mode automatic differentiation
558    fn forward_mode_gradient(
559        &self,
560        loss_fn: impl Fn(&Array2<Complex<f64>>) -> f64,
561    ) -> QuantRS2Result<Vec<f64>> {
562        // Simplified implementation - would use dual numbers throughout
563        let mut gradients = vec![0.0; self.params.len()];
564
565        // For demonstration, use finite differences as fallback
566        self.finite_diff_gradient(loss_fn, 1e-8)
567    }
568
569    /// Reverse-mode automatic differentiation
570    fn reverse_mode_gradient(
571        &self,
572        loss_fn: impl Fn(&Array2<Complex<f64>>) -> f64,
573    ) -> QuantRS2Result<Vec<f64>> {
574        // Build computation graph
575        let mut graph = ComputationGraph::new();
576
577        // Add parameters to graph
578        let param_nodes: Vec<_> = self
579            .params
580            .iter()
581            .zip(&self.values)
582            .map(|(name, &value)| graph.parameter(name.clone(), value))
583            .collect();
584
585        // Compute matrix elements using graph
586        // This is simplified - full implementation would build entire matrix computation
587
588        // For now, use parameter shift as fallback
589        self.parameter_shift_gradient(loss_fn)
590    }
591}
592
593impl std::fmt::Debug for VariationalGate {
594    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
595        f.debug_struct("VariationalGate")
596            .field("name", &self.name)
597            .field("qubits", &self.qubits)
598            .field("params", &self.params)
599            .field("values", &self.values)
600            .field("diff_mode", &self.diff_mode)
601            .finish()
602    }
603}
604
605impl GateOp for VariationalGate {
606    fn name(&self) -> &'static str {
607        // We need to leak the string to get a 'static lifetime
608        // This is safe for gate names which are created once
609        Box::leak(self.name.clone().into_boxed_str())
610    }
611
612    fn qubits(&self) -> Vec<QubitId> {
613        self.qubits.clone()
614    }
615
616    fn is_parameterized(&self) -> bool {
617        true
618    }
619
620    fn matrix(&self) -> QuantRS2Result<Vec<Complex<f64>>> {
621        let mat = (self.generator)(&self.values);
622        Ok(mat.iter().cloned().collect())
623    }
624
625    fn as_any(&self) -> &dyn std::any::Any {
626        self
627    }
628
629    fn clone_gate(&self) -> Box<dyn GateOp> {
630        Box::new(self.clone())
631    }
632}
633
634/// Variational quantum circuit with multiple parameterized gates
635#[derive(Debug)]
636pub struct VariationalCircuit {
637    /// List of gates in the circuit
638    pub gates: Vec<VariationalGate>,
639    /// Parameter name to gate indices mapping
640    pub param_map: FxHashMap<String, Vec<usize>>,
641    /// Number of qubits
642    pub num_qubits: usize,
643}
644
645impl VariationalCircuit {
646    /// Create a new variational circuit
647    pub fn new(num_qubits: usize) -> Self {
648        Self {
649            gates: Vec::new(),
650            param_map: FxHashMap::default(),
651            num_qubits,
652        }
653    }
654
655    /// Add a variational gate to the circuit
656    pub fn add_gate(&mut self, gate: VariationalGate) {
657        let gate_idx = self.gates.len();
658
659        // Update parameter map
660        for param in &gate.params {
661            self.param_map
662                .entry(param.clone())
663                .or_insert_with(Vec::new)
664                .push(gate_idx);
665        }
666
667        self.gates.push(gate);
668    }
669
670    /// Get all parameter names
671    pub fn parameter_names(&self) -> Vec<String> {
672        let mut names: Vec<_> = self.param_map.keys().cloned().collect();
673        names.sort();
674        names
675    }
676
677    /// Get current parameter values
678    pub fn get_parameters(&self) -> FxHashMap<String, f64> {
679        let mut params = FxHashMap::default();
680
681        for gate in &self.gates {
682            for (name, &value) in gate.params.iter().zip(&gate.values) {
683                params.insert(name.clone(), value);
684            }
685        }
686
687        params
688    }
689
690    /// Set parameter values
691    pub fn set_parameters(&mut self, params: &FxHashMap<String, f64>) -> QuantRS2Result<()> {
692        for (param_name, &value) in params {
693            if let Some(gate_indices) = self.param_map.get(param_name) {
694                for &idx in gate_indices {
695                    if let Some(param_idx) =
696                        self.gates[idx].params.iter().position(|p| p == param_name)
697                    {
698                        self.gates[idx].values[param_idx] = value;
699                    }
700                }
701            }
702        }
703
704        Ok(())
705    }
706
707    /// Compute gradients for all parameters
708    pub fn compute_gradients(
709        &self,
710        loss_fn: impl Fn(&[VariationalGate]) -> f64,
711    ) -> QuantRS2Result<FxHashMap<String, f64>> {
712        let mut gradients = FxHashMap::default();
713
714        // Use parameter shift rule for each parameter
715        for param_name in self.parameter_names() {
716            let grad = self.parameter_gradient(param_name.as_str(), &loss_fn)?;
717            gradients.insert(param_name, grad);
718        }
719
720        Ok(gradients)
721    }
722
723    /// Compute gradient for a single parameter
724    fn parameter_gradient(
725        &self,
726        param_name: &str,
727        loss_fn: &impl Fn(&[VariationalGate]) -> f64,
728    ) -> QuantRS2Result<f64> {
729        let current_params = self.get_parameters();
730        let current_value = *current_params.get(param_name).ok_or_else(|| {
731            QuantRS2Error::InvalidInput(format!("Parameter {} not found", param_name))
732        })?;
733
734        // Create circuit copies with shifted parameters
735        let mut circuit_plus = self.clone_circuit();
736        let mut params_plus = current_params.clone();
737        params_plus.insert(param_name.to_string(), current_value + PI / 2.0);
738        circuit_plus.set_parameters(&params_plus)?;
739
740        let mut circuit_minus = self.clone_circuit();
741        let mut params_minus = current_params.clone();
742        params_minus.insert(param_name.to_string(), current_value - PI / 2.0);
743        circuit_minus.set_parameters(&params_minus)?;
744
745        // Compute gradient via parameter shift
746        let loss_plus = loss_fn(&circuit_plus.gates);
747        let loss_minus = loss_fn(&circuit_minus.gates);
748
749        Ok((loss_plus - loss_minus) / 2.0)
750    }
751
752    /// Clone the circuit structure
753    fn clone_circuit(&self) -> Self {
754        Self {
755            gates: self.gates.clone(),
756            param_map: self.param_map.clone(),
757            num_qubits: self.num_qubits,
758        }
759    }
760}
761
762/// Gradient-based optimizer for variational circuits
763#[derive(Debug, Clone)]
764pub struct VariationalOptimizer {
765    /// Learning rate
766    pub learning_rate: f64,
767    /// Momentum coefficient
768    pub momentum: f64,
769    /// Accumulated momentum
770    velocities: FxHashMap<String, f64>,
771}
772
773impl VariationalOptimizer {
774    /// Create a new optimizer
775    pub fn new(learning_rate: f64, momentum: f64) -> Self {
776        Self {
777            learning_rate,
778            momentum,
779            velocities: FxHashMap::default(),
780        }
781    }
782
783    /// Perform one optimization step
784    pub fn step(
785        &mut self,
786        circuit: &mut VariationalCircuit,
787        gradients: &FxHashMap<String, f64>,
788    ) -> QuantRS2Result<()> {
789        let mut new_params = circuit.get_parameters();
790
791        for (param_name, &grad) in gradients {
792            // Update velocity with momentum
793            let velocity = self.velocities.entry(param_name.clone()).or_insert(0.0);
794            *velocity = self.momentum * *velocity - self.learning_rate * grad;
795
796            // Update parameter
797            if let Some(value) = new_params.get_mut(param_name) {
798                *value += *velocity;
799            }
800        }
801
802        circuit.set_parameters(&new_params)
803    }
804}
805
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dual_arithmetic() {
        let a = Dual::variable(2.0);
        let b = Dual::constant(3.0);

        let c = a + b;
        assert_eq!(c.real, 5.0);
        assert_eq!(c.dual, 1.0);

        let d = a * b;
        assert_eq!(d.real, 6.0);
        assert_eq!(d.dual, 3.0);

        // Quotient rule: d/dx (x / 3) = 1 / 3.
        let q = a / b;
        assert!((q.real - 2.0 / 3.0).abs() < 1e-12);
        assert!((q.dual - 1.0 / 3.0).abs() < 1e-12);

        let e = a.sin();
        assert!((e.real - 2.0_f64.sin()).abs() < 1e-10);
        assert!((e.dual - 2.0_f64.cos()).abs() < 1e-10);
    }

    #[test]
    fn test_computation_graph_backward() {
        let mut graph = ComputationGraph::new();
        let x = graph.parameter("x".to_string(), 2.0);
        let c = graph.constant(Complex::new(3.0, 0.0));
        let prod = graph.mul(x, c);
        let out = graph.add(prod, x);

        graph.backward(out);

        // d/dx (3x + x) = 4
        let grad = graph.get_gradient("x").unwrap();
        assert!((grad - 4.0).abs() < 1e-12, "got {}", grad);
        assert_eq!(graph.get_gradient("missing"), None);
    }

    #[test]
    fn test_variational_rx_gate() {
        let gate = VariationalGate::rx(QubitId(0), "theta".to_string(), PI / 4.0);

        let matrix_vec = gate.matrix().unwrap();
        assert_eq!(matrix_vec.len(), 4);

        // Convert to Array2 for unitary check
        let matrix = Array2::from_shape_vec((2, 2), matrix_vec).unwrap();
        let mat = DenseMatrix::new(matrix).unwrap();
        assert!(mat.is_unitary(1e-10).unwrap());
    }

    #[test]
    fn test_parameter_shift_gradient() {
        let theta = PI / 3.0;
        let gate = VariationalGate::ry(QubitId(0), "phi".to_string(), theta);

        // Simple linear loss: the real part of the matrix trace, Re tr RY(θ) = 2·cos(θ/2).
        let loss_fn =
            |matrix: &Array2<Complex<f64>>| -> f64 { (matrix[[0, 0]] + matrix[[1, 1]]).re };

        let gradients = gate.gradient(loss_fn).unwrap();
        assert_eq!(gradients.len(), 1);

        // The two-term rule is applied literally, so the expected value is
        // [f(θ + π/2) − f(θ − π/2)] / 2 with f(θ) = 2·cos(θ/2).
        let plus_shift = 2.0 * ((theta + PI / 2.0) / 2.0).cos();
        let minus_shift = 2.0 * ((theta - PI / 2.0) / 2.0).cos();
        let expected = (plus_shift - minus_shift) / 2.0;

        assert!(
            (gradients[0] - expected).abs() < 1e-5,
            "Expected gradient: {}, got: {}",
            expected,
            gradients[0]
        );
    }

    #[test]
    fn test_finite_diff_gradient() {
        let theta = 0.7;
        let mut gate = VariationalGate::ry(QubitId(0), "phi".to_string(), theta);
        gate.diff_mode = DiffMode::FiniteDiff { epsilon: 1e-6 };

        // d/dθ [2·cos(θ/2)] = −sin(θ/2)
        let gradients = gate
            .gradient(|matrix: &Array2<Complex<f64>>| (matrix[[0, 0]] + matrix[[1, 1]]).re)
            .unwrap();
        assert_eq!(gradients.len(), 1);
        assert!(
            (gradients[0] + (theta / 2.0).sin()).abs() < 1e-5,
            "got {}",
            gradients[0]
        );
    }

    #[test]
    fn test_variational_circuit() {
        let mut circuit = VariationalCircuit::new(2);

        circuit.add_gate(VariationalGate::rx(QubitId(0), "theta1".to_string(), 0.1));
        circuit.add_gate(VariationalGate::ry(QubitId(1), "theta2".to_string(), 0.2));
        circuit.add_gate(VariationalGate::cry(
            QubitId(0),
            QubitId(1),
            "theta3".to_string(),
            0.3,
        ));

        assert_eq!(circuit.gates.len(), 3);
        assert_eq!(circuit.parameter_names().len(), 3);

        // Update parameters
        let mut new_params = FxHashMap::default();
        new_params.insert("theta1".to_string(), 0.5);
        new_params.insert("theta2".to_string(), 0.6);
        new_params.insert("theta3".to_string(), 0.7);

        circuit.set_parameters(&new_params).unwrap();

        let params = circuit.get_parameters();
        assert_eq!(params.get("theta1"), Some(&0.5));
        assert_eq!(params.get("theta2"), Some(&0.6));
        assert_eq!(params.get("theta3"), Some(&0.7));
    }

    #[test]
    fn test_optimizer() {
        let mut circuit = VariationalCircuit::new(1);
        circuit.add_gate(VariationalGate::rx(QubitId(0), "theta".to_string(), 0.0));

        let mut optimizer = VariationalOptimizer::new(0.1, 0.9);

        // Constant dummy gradient
        let mut gradients = FxHashMap::default();
        gradients.insert("theta".to_string(), 1.0);

        // v₁ = 0.9·0 − 0.1·1 = −0.1, θ₁ = 0 + v₁ = −0.1
        optimizer.step(&mut circuit, &gradients).unwrap();
        let theta = *circuit.get_parameters().get("theta").unwrap();
        assert!((theta + 0.1).abs() < 1e-12, "got {}", theta);

        // v₂ = 0.9·(−0.1) − 0.1·1 = −0.19, θ₂ = −0.1 + v₂ = −0.29
        optimizer.step(&mut circuit, &gradients).unwrap();
        let theta = *circuit.get_parameters().get("theta").unwrap();
        assert!((theta + 0.29).abs() < 1e-12, "got {}", theta);
    }
}