1use crate::{
8 error::{QuantRS2Error, QuantRS2Result},
9 gate::GateOp,
10 matrix_ops::{DenseMatrix, QuantumMatrix},
11 qubit::QubitId,
12 register::Register,
13};
14use ndarray::{Array1, Array2};
15use num_complex::Complex;
16use rustc_hash::FxHashMap;
17use std::any::Any;
18use std::f64::consts::PI;
19use std::sync::Arc;
20
/// How the gradient of a [`VariationalGate`] with respect to its parameters is
/// evaluated.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DiffMode {
    /// Forward-mode request. `VariationalGate` currently evaluates it with a
    /// one-sided finite difference (step `1e-8`).
    Forward,
    /// Reverse-mode request. `VariationalGate` currently evaluates it with the
    /// parameter-shift rule.
    Reverse,
    /// Two-term parameter-shift rule: `(L(θ + π/2) - L(θ - π/2)) / 2`.
    ParameterShift,
    /// One-sided finite difference `(L(θ + ε) - L(θ)) / ε` with step `epsilon`.
    FiniteDiff { epsilon: f64 },
}
33
/// A dual number `real + dual·ε` (with `ε² = 0`) for forward-mode automatic
/// differentiation: `real` carries the value and `dual` carries the derivative
/// with respect to whichever input was seeded with a dual part of `1`.
///
/// `PartialEq` allows whole dual numbers to be compared. `Default` yields the
/// zero dual number.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct Dual {
    /// Primal (function) value.
    pub real: f64,
    /// Tangent (derivative) component.
    pub dual: f64,
}
42
43impl Dual {
44 pub fn new(real: f64, dual: f64) -> Self {
46 Self { real, dual }
47 }
48
49 pub fn constant(value: f64) -> Self {
51 Self {
52 real: value,
53 dual: 0.0,
54 }
55 }
56
57 pub fn variable(value: f64) -> Self {
59 Self {
60 real: value,
61 dual: 1.0,
62 }
63 }
64}
65
66impl std::ops::Add for Dual {
68 type Output = Self;
69
70 fn add(self, other: Self) -> Self {
71 Self {
72 real: self.real + other.real,
73 dual: self.dual + other.dual,
74 }
75 }
76}
77
78impl std::ops::Sub for Dual {
79 type Output = Self;
80
81 fn sub(self, other: Self) -> Self {
82 Self {
83 real: self.real - other.real,
84 dual: self.dual - other.dual,
85 }
86 }
87}
88
89impl std::ops::Mul for Dual {
90 type Output = Self;
91
92 fn mul(self, other: Self) -> Self {
93 Self {
94 real: self.real * other.real,
95 dual: self.real * other.dual + self.dual * other.real,
96 }
97 }
98}
99
100impl std::ops::Div for Dual {
101 type Output = Self;
102
103 fn div(self, other: Self) -> Self {
104 Self {
105 real: self.real / other.real,
106 dual: (self.dual * other.real - self.real * other.dual) / (other.real * other.real),
107 }
108 }
109}
110
111impl Dual {
113 pub fn sin(self) -> Self {
114 Self {
115 real: self.real.sin(),
116 dual: self.dual * self.real.cos(),
117 }
118 }
119
120 pub fn cos(self) -> Self {
121 Self {
122 real: self.real.cos(),
123 dual: -self.dual * self.real.sin(),
124 }
125 }
126
127 pub fn exp(self) -> Self {
128 let exp_real = self.real.exp();
129 Self {
130 real: exp_real,
131 dual: self.dual * exp_real,
132 }
133 }
134
135 pub fn sqrt(self) -> Self {
136 let sqrt_real = self.real.sqrt();
137 Self {
138 real: sqrt_real,
139 dual: self.dual / (2.0 * sqrt_real),
140 }
141 }
142}
143
144#[derive(Debug, Clone)]
146pub struct Node {
147 pub id: usize,
149 pub value: Complex<f64>,
151 pub grad: Complex<f64>,
153 pub op: Operation,
155 pub parents: Vec<usize>,
157}
158
/// Primitive operations recorded in a [`ComputationGraph`].
#[derive(Clone, Debug)]
pub enum Operation {
    /// Named leaf whose gradient can be queried by name.
    Parameter(String),
    /// Leaf with a fixed value.
    Constant,
    /// Sum of two nodes.
    Add,
    /// Product of two nodes.
    Mul,
    /// Complex conjugate. No builder in this module creates it, and `backward`
    /// does not propagate through it.
    Conj,
    /// Matrix product. No builder in this module creates it, and `backward`
    /// does not propagate through it.
    MatMul,
    /// Unit phasor `e^{iθ}` built from the real part of its single input.
    ExpI,
}
177
178#[derive(Debug)]
180pub struct ComputationGraph {
181 nodes: Vec<Node>,
183 params: FxHashMap<String, usize>,
185 next_id: usize,
187}
188
189impl ComputationGraph {
190 pub fn new() -> Self {
192 Self {
193 nodes: Vec::new(),
194 params: FxHashMap::default(),
195 next_id: 0,
196 }
197 }
198
199 pub fn parameter(&mut self, name: String, value: f64) -> usize {
201 let id = self.next_id;
202 self.next_id += 1;
203
204 let node = Node {
205 id,
206 value: Complex::new(value, 0.0),
207 grad: Complex::new(0.0, 0.0),
208 op: Operation::Parameter(name.clone()),
209 parents: vec![],
210 };
211
212 self.nodes.push(node);
213 self.params.insert(name, id);
214 id
215 }
216
217 pub fn constant(&mut self, value: Complex<f64>) -> usize {
219 let id = self.next_id;
220 self.next_id += 1;
221
222 let node = Node {
223 id,
224 value,
225 grad: Complex::new(0.0, 0.0),
226 op: Operation::Constant,
227 parents: vec![],
228 };
229
230 self.nodes.push(node);
231 id
232 }
233
234 pub fn add(&mut self, a: usize, b: usize) -> usize {
236 let id = self.next_id;
237 self.next_id += 1;
238
239 let value = self.nodes[a].value + self.nodes[b].value;
240
241 let node = Node {
242 id,
243 value,
244 grad: Complex::new(0.0, 0.0),
245 op: Operation::Add,
246 parents: vec![a, b],
247 };
248
249 self.nodes.push(node);
250 id
251 }
252
253 pub fn mul(&mut self, a: usize, b: usize) -> usize {
255 let id = self.next_id;
256 self.next_id += 1;
257
258 let value = self.nodes[a].value * self.nodes[b].value;
259
260 let node = Node {
261 id,
262 value,
263 grad: Complex::new(0.0, 0.0),
264 op: Operation::Mul,
265 parents: vec![a, b],
266 };
267
268 self.nodes.push(node);
269 id
270 }
271
272 pub fn exp_i(&mut self, theta: usize) -> usize {
274 let id = self.next_id;
275 self.next_id += 1;
276
277 let theta_val = self.nodes[theta].value.re;
278 let value = Complex::new(theta_val.cos(), theta_val.sin());
279
280 let node = Node {
281 id,
282 value,
283 grad: Complex::new(0.0, 0.0),
284 op: Operation::ExpI,
285 parents: vec![theta],
286 };
287
288 self.nodes.push(node);
289 id
290 }
291
292 pub fn backward(&mut self, output: usize) {
294 self.nodes[output].grad = Complex::new(1.0, 0.0);
296
297 for i in (0..=output).rev() {
299 let grad = self.nodes[i].grad;
300 let parents = self.nodes[i].parents.clone();
301 let op = self.nodes[i].op.clone();
302
303 match op {
304 Operation::Add => {
305 if !parents.is_empty() {
307 self.nodes[parents[0]].grad += grad;
308 self.nodes[parents[1]].grad += grad;
309 }
310 }
311 Operation::Mul => {
312 if !parents.is_empty() {
314 let a = parents[0];
315 let b = parents[1];
316 let b_value = self.nodes[b].value;
317 let a_value = self.nodes[a].value;
318 self.nodes[a].grad += grad * b_value;
319 self.nodes[b].grad += grad * a_value;
320 }
321 }
322 Operation::ExpI => {
323 if !parents.is_empty() {
325 let theta = parents[0];
326 let node_value = self.nodes[i].value;
327 self.nodes[theta].grad += grad * Complex::new(0.0, 1.0) * node_value;
328 }
329 }
330 _ => {}
331 }
332 }
333 }
334
335 pub fn get_gradient(&self, param: &str) -> Option<f64> {
337 self.params.get(param).map(|&id| self.nodes[id].grad.re)
338 }
339}
340
341#[derive(Clone)]
343pub struct VariationalGate {
344 pub name: String,
346 pub qubits: Vec<QubitId>,
348 pub params: Vec<String>,
350 pub values: Vec<f64>,
352 pub generator: Arc<dyn Fn(&[f64]) -> Array2<Complex<f64>> + Send + Sync>,
354 pub diff_mode: DiffMode,
356}
357
358impl VariationalGate {
359 pub fn rx(qubit: QubitId, param_name: String, initial_value: f64) -> Self {
361 let generator = Arc::new(|params: &[f64]| {
362 let theta = params[0];
363 let cos_half = (theta / 2.0).cos();
364 let sin_half = (theta / 2.0).sin();
365
366 Array2::from_shape_vec(
367 (2, 2),
368 vec![
369 Complex::new(cos_half, 0.0),
370 Complex::new(0.0, -sin_half),
371 Complex::new(0.0, -sin_half),
372 Complex::new(cos_half, 0.0),
373 ],
374 )
375 .unwrap()
376 });
377
378 Self {
379 name: format!("RX({})", param_name),
380 qubits: vec![qubit],
381 params: vec![param_name],
382 values: vec![initial_value],
383 generator,
384 diff_mode: DiffMode::ParameterShift,
385 }
386 }
387
388 pub fn ry(qubit: QubitId, param_name: String, initial_value: f64) -> Self {
390 let generator = Arc::new(|params: &[f64]| {
391 let theta = params[0];
392 let cos_half = (theta / 2.0).cos();
393 let sin_half = (theta / 2.0).sin();
394
395 Array2::from_shape_vec(
396 (2, 2),
397 vec![
398 Complex::new(cos_half, 0.0),
399 Complex::new(-sin_half, 0.0),
400 Complex::new(sin_half, 0.0),
401 Complex::new(cos_half, 0.0),
402 ],
403 )
404 .unwrap()
405 });
406
407 Self {
408 name: format!("RY({})", param_name),
409 qubits: vec![qubit],
410 params: vec![param_name],
411 values: vec![initial_value],
412 generator,
413 diff_mode: DiffMode::ParameterShift,
414 }
415 }
416
417 pub fn rz(qubit: QubitId, param_name: String, initial_value: f64) -> Self {
419 let generator = Arc::new(|params: &[f64]| {
420 let theta = params[0];
421 let exp_pos = Complex::new((theta / 2.0).cos(), (theta / 2.0).sin());
422 let exp_neg = Complex::new((theta / 2.0).cos(), -(theta / 2.0).sin());
423
424 Array2::from_shape_vec(
425 (2, 2),
426 vec![
427 exp_neg,
428 Complex::new(0.0, 0.0),
429 Complex::new(0.0, 0.0),
430 exp_pos,
431 ],
432 )
433 .unwrap()
434 });
435
436 Self {
437 name: format!("RZ({})", param_name),
438 qubits: vec![qubit],
439 params: vec![param_name],
440 values: vec![initial_value],
441 generator,
442 diff_mode: DiffMode::ParameterShift,
443 }
444 }
445
446 pub fn cry(control: QubitId, target: QubitId, param_name: String, initial_value: f64) -> Self {
448 let generator = Arc::new(|params: &[f64]| {
449 let theta = params[0];
450 let cos_half = (theta / 2.0).cos();
451 let sin_half = (theta / 2.0).sin();
452
453 let mut matrix = Array2::eye(4).mapv(|x| Complex::new(x, 0.0));
454 matrix[[2, 2]] = Complex::new(cos_half, 0.0);
456 matrix[[2, 3]] = Complex::new(-sin_half, 0.0);
457 matrix[[3, 2]] = Complex::new(sin_half, 0.0);
458 matrix[[3, 3]] = Complex::new(cos_half, 0.0);
459
460 matrix
461 });
462
463 Self {
464 name: format!("CRY({}, {})", param_name, control.0),
465 qubits: vec![control, target],
466 params: vec![param_name],
467 values: vec![initial_value],
468 generator,
469 diff_mode: DiffMode::ParameterShift,
470 }
471 }
472
473 pub fn get_params(&self) -> &[f64] {
475 &self.values
476 }
477
478 pub fn set_params(&mut self, values: Vec<f64>) -> QuantRS2Result<()> {
480 if values.len() != self.params.len() {
481 return Err(QuantRS2Error::InvalidInput(format!(
482 "Expected {} parameters, got {}",
483 self.params.len(),
484 values.len()
485 )));
486 }
487 self.values = values;
488 Ok(())
489 }
490
491 pub fn gradient(
493 &self,
494 loss_fn: impl Fn(&Array2<Complex<f64>>) -> f64,
495 ) -> QuantRS2Result<Vec<f64>> {
496 match self.diff_mode {
497 DiffMode::ParameterShift => self.parameter_shift_gradient(loss_fn),
498 DiffMode::FiniteDiff { epsilon } => self.finite_diff_gradient(loss_fn, epsilon),
499 DiffMode::Forward => self.forward_mode_gradient(loss_fn),
500 DiffMode::Reverse => self.reverse_mode_gradient(loss_fn),
501 }
502 }
503
504 fn parameter_shift_gradient(
506 &self,
507 loss_fn: impl Fn(&Array2<Complex<f64>>) -> f64,
508 ) -> QuantRS2Result<Vec<f64>> {
509 let mut gradients = vec![0.0; self.params.len()];
510
511 for (i, &value) in self.values.iter().enumerate() {
512 let mut params_plus = self.values.clone();
514 params_plus[i] = value + PI / 2.0;
515 let matrix_plus = (self.generator)(¶ms_plus);
516 let loss_plus = loss_fn(&matrix_plus);
517
518 let mut params_minus = self.values.clone();
520 params_minus[i] = value - PI / 2.0;
521 let matrix_minus = (self.generator)(¶ms_minus);
522 let loss_minus = loss_fn(&matrix_minus);
523
524 gradients[i] = (loss_plus - loss_minus) / 2.0;
526 }
527
528 Ok(gradients)
529 }
530
531 fn finite_diff_gradient(
533 &self,
534 loss_fn: impl Fn(&Array2<Complex<f64>>) -> f64,
535 epsilon: f64,
536 ) -> QuantRS2Result<Vec<f64>> {
537 let mut gradients = vec![0.0; self.params.len()];
538
539 for (i, &value) in self.values.iter().enumerate() {
540 let mut params_plus = self.values.clone();
542 params_plus[i] = value + epsilon;
543 let matrix_plus = (self.generator)(¶ms_plus);
544 let loss_plus = loss_fn(&matrix_plus);
545
546 let matrix = (self.generator)(&self.values);
548 let loss = loss_fn(&matrix);
549
550 gradients[i] = (loss_plus - loss) / epsilon;
552 }
553
554 Ok(gradients)
555 }
556
557 fn forward_mode_gradient(
559 &self,
560 loss_fn: impl Fn(&Array2<Complex<f64>>) -> f64,
561 ) -> QuantRS2Result<Vec<f64>> {
562 let mut gradients = vec![0.0; self.params.len()];
564
565 self.finite_diff_gradient(loss_fn, 1e-8)
567 }
568
569 fn reverse_mode_gradient(
571 &self,
572 loss_fn: impl Fn(&Array2<Complex<f64>>) -> f64,
573 ) -> QuantRS2Result<Vec<f64>> {
574 let mut graph = ComputationGraph::new();
576
577 let param_nodes: Vec<_> = self
579 .params
580 .iter()
581 .zip(&self.values)
582 .map(|(name, &value)| graph.parameter(name.clone(), value))
583 .collect();
584
585 self.parameter_shift_gradient(loss_fn)
590 }
591}
592
593impl std::fmt::Debug for VariationalGate {
594 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
595 f.debug_struct("VariationalGate")
596 .field("name", &self.name)
597 .field("qubits", &self.qubits)
598 .field("params", &self.params)
599 .field("values", &self.values)
600 .field("diff_mode", &self.diff_mode)
601 .finish()
602 }
603}
604
605impl GateOp for VariationalGate {
606 fn name(&self) -> &'static str {
607 Box::leak(self.name.clone().into_boxed_str())
610 }
611
612 fn qubits(&self) -> Vec<QubitId> {
613 self.qubits.clone()
614 }
615
616 fn is_parameterized(&self) -> bool {
617 true
618 }
619
620 fn matrix(&self) -> QuantRS2Result<Vec<Complex<f64>>> {
621 let mat = (self.generator)(&self.values);
622 Ok(mat.iter().cloned().collect())
623 }
624
625 fn as_any(&self) -> &dyn std::any::Any {
626 self
627 }
628
629 fn clone_gate(&self) -> Box<dyn GateOp> {
630 Box::new(self.clone())
631 }
632}
633
634#[derive(Debug)]
636pub struct VariationalCircuit {
637 pub gates: Vec<VariationalGate>,
639 pub param_map: FxHashMap<String, Vec<usize>>,
641 pub num_qubits: usize,
643}
644
645impl VariationalCircuit {
646 pub fn new(num_qubits: usize) -> Self {
648 Self {
649 gates: Vec::new(),
650 param_map: FxHashMap::default(),
651 num_qubits,
652 }
653 }
654
655 pub fn add_gate(&mut self, gate: VariationalGate) {
657 let gate_idx = self.gates.len();
658
659 for param in &gate.params {
661 self.param_map
662 .entry(param.clone())
663 .or_insert_with(Vec::new)
664 .push(gate_idx);
665 }
666
667 self.gates.push(gate);
668 }
669
670 pub fn parameter_names(&self) -> Vec<String> {
672 let mut names: Vec<_> = self.param_map.keys().cloned().collect();
673 names.sort();
674 names
675 }
676
677 pub fn get_parameters(&self) -> FxHashMap<String, f64> {
679 let mut params = FxHashMap::default();
680
681 for gate in &self.gates {
682 for (name, &value) in gate.params.iter().zip(&gate.values) {
683 params.insert(name.clone(), value);
684 }
685 }
686
687 params
688 }
689
690 pub fn set_parameters(&mut self, params: &FxHashMap<String, f64>) -> QuantRS2Result<()> {
692 for (param_name, &value) in params {
693 if let Some(gate_indices) = self.param_map.get(param_name) {
694 for &idx in gate_indices {
695 if let Some(param_idx) =
696 self.gates[idx].params.iter().position(|p| p == param_name)
697 {
698 self.gates[idx].values[param_idx] = value;
699 }
700 }
701 }
702 }
703
704 Ok(())
705 }
706
707 pub fn compute_gradients(
709 &self,
710 loss_fn: impl Fn(&[VariationalGate]) -> f64,
711 ) -> QuantRS2Result<FxHashMap<String, f64>> {
712 let mut gradients = FxHashMap::default();
713
714 for param_name in self.parameter_names() {
716 let grad = self.parameter_gradient(param_name.as_str(), &loss_fn)?;
717 gradients.insert(param_name, grad);
718 }
719
720 Ok(gradients)
721 }
722
723 fn parameter_gradient(
725 &self,
726 param_name: &str,
727 loss_fn: &impl Fn(&[VariationalGate]) -> f64,
728 ) -> QuantRS2Result<f64> {
729 let current_params = self.get_parameters();
730 let current_value = *current_params.get(param_name).ok_or_else(|| {
731 QuantRS2Error::InvalidInput(format!("Parameter {} not found", param_name))
732 })?;
733
734 let mut circuit_plus = self.clone_circuit();
736 let mut params_plus = current_params.clone();
737 params_plus.insert(param_name.to_string(), current_value + PI / 2.0);
738 circuit_plus.set_parameters(¶ms_plus)?;
739
740 let mut circuit_minus = self.clone_circuit();
741 let mut params_minus = current_params.clone();
742 params_minus.insert(param_name.to_string(), current_value - PI / 2.0);
743 circuit_minus.set_parameters(¶ms_minus)?;
744
745 let loss_plus = loss_fn(&circuit_plus.gates);
747 let loss_minus = loss_fn(&circuit_minus.gates);
748
749 Ok((loss_plus - loss_minus) / 2.0)
750 }
751
752 fn clone_circuit(&self) -> Self {
754 Self {
755 gates: self.gates.clone(),
756 param_map: self.param_map.clone(),
757 num_qubits: self.num_qubits,
758 }
759 }
760}
761
762#[derive(Debug, Clone)]
764pub struct VariationalOptimizer {
765 pub learning_rate: f64,
767 pub momentum: f64,
769 velocities: FxHashMap<String, f64>,
771}
772
773impl VariationalOptimizer {
774 pub fn new(learning_rate: f64, momentum: f64) -> Self {
776 Self {
777 learning_rate,
778 momentum,
779 velocities: FxHashMap::default(),
780 }
781 }
782
783 pub fn step(
785 &mut self,
786 circuit: &mut VariationalCircuit,
787 gradients: &FxHashMap<String, f64>,
788 ) -> QuantRS2Result<()> {
789 let mut new_params = circuit.get_parameters();
790
791 for (param_name, &grad) in gradients {
792 let velocity = self.velocities.entry(param_name.clone()).or_insert(0.0);
794 *velocity = self.momentum * *velocity - self.learning_rate * grad;
795
796 if let Some(value) = new_params.get_mut(param_name) {
798 *value += *velocity;
799 }
800 }
801
802 circuit.set_parameters(&new_params)
803 }
804}
805
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const TOL: f64 = 1e-12;

    #[test]
    fn test_dual_arithmetic() {
        let a = Dual::variable(2.0);
        let b = Dual::constant(3.0);

        let c = a + b;
        assert_eq!(c.real, 5.0);
        assert_eq!(c.dual, 1.0);

        let d = a * b;
        assert_eq!(d.real, 6.0);
        assert_eq!(d.dual, 3.0);

        let e = a.sin();
        assert!((e.real - 2.0_f64.sin()).abs() < 1e-10);
        assert!((e.dual - 2.0_f64.cos()).abs() < 1e-10);
    }

    #[test]
    fn test_dual_division_and_elementary_functions() {
        let a = Dual::variable(2.0);
        let b = Dual::constant(4.0);

        let diff = a - b;
        assert_eq!(diff.real, -2.0);
        assert_eq!(diff.dual, 1.0);

        // d/dx (x / 4) = 1/4
        let q = a / b;
        assert!((q.real - 0.5).abs() < TOL);
        assert!((q.dual - 0.25).abs() < TOL);

        // d/dx sqrt(x) at 4 = 1 / (2·2)
        let r = Dual::variable(4.0).sqrt();
        assert!((r.real - 2.0).abs() < TOL);
        assert!((r.dual - 0.25).abs() < TOL);

        let e = Dual::variable(0.0).exp();
        assert!((e.real - 1.0).abs() < TOL);
        assert!((e.dual - 1.0).abs() < TOL);

        let c = Dual::variable(0.5).cos();
        assert!((c.real - 0.5_f64.cos()).abs() < TOL);
        assert!((c.dual + 0.5_f64.sin()).abs() < TOL);
    }

    #[test]
    fn test_computation_graph_gradients() {
        // f(x, y) = x·y + x  ⇒  ∂f/∂x = y + 1, ∂f/∂y = x
        let mut graph = ComputationGraph::new();
        let x = graph.parameter("x".to_string(), 2.0);
        let y = graph.parameter("y".to_string(), 3.0);
        let product = graph.mul(x, y);
        let sum = graph.add(product, x);
        graph.backward(sum);
        assert!((graph.get_gradient("x").unwrap() - 4.0).abs() < TOL);
        assert!((graph.get_gradient("y").unwrap() - 2.0).abs() < TOL);
        assert_eq!(graph.get_gradient("missing"), None);

        // Re(d e^{iθ}/dθ) = Re(i·e^{iθ}) = −sin θ
        let mut graph = ComputationGraph::new();
        let t = graph.parameter("t".to_string(), 0.3);
        let phase = graph.exp_i(t);
        graph.backward(phase);
        assert!((graph.get_gradient("t").unwrap() + 0.3_f64.sin()).abs() < TOL);
    }

    #[test]
    fn test_variational_rx_gate() {
        let gate = VariationalGate::rx(QubitId(0), "theta".to_string(), PI / 4.0);

        let matrix_vec = gate.matrix().unwrap();
        assert_eq!(matrix_vec.len(), 4);

        let matrix = Array2::from_shape_vec((2, 2), matrix_vec).unwrap();
        let mat = DenseMatrix::new(matrix).unwrap();
        assert!(mat.is_unitary(1e-10).unwrap());
    }

    #[test]
    fn test_parameter_shift_gradient() {
        let theta = PI / 3.0;
        let gate = VariationalGate::ry(QubitId(0), "phi".to_string(), theta);

        // Real part of the trace: 2·cos(θ/2) for RY.
        let loss_fn = |matrix: &Array2<Complex<f64>>| -> f64 {
            (matrix[[0, 0]] + matrix[[1, 1]]).re
        };

        let gradients = gate.gradient(loss_fn).unwrap();
        assert_eq!(gradients.len(), 1);

        let plus_shift = 2.0 * ((theta + PI / 2.0) / 2.0).cos();
        let minus_shift = 2.0 * ((theta - PI / 2.0) / 2.0).cos();
        let expected = (plus_shift - minus_shift) / 2.0;

        assert!(
            (gradients[0] - expected).abs() < 1e-5,
            "Expected gradient: {}, got: {}",
            expected,
            gradients[0]
        );
    }

    #[test]
    fn test_variational_circuit() {
        let mut circuit = VariationalCircuit::new(2);

        circuit.add_gate(VariationalGate::rx(QubitId(0), "theta1".to_string(), 0.1));
        circuit.add_gate(VariationalGate::ry(QubitId(1), "theta2".to_string(), 0.2));
        circuit.add_gate(VariationalGate::cry(
            QubitId(0),
            QubitId(1),
            "theta3".to_string(),
            0.3,
        ));

        assert_eq!(circuit.gates.len(), 3);
        assert_eq!(circuit.parameter_names().len(), 3);

        let mut new_params = FxHashMap::default();
        new_params.insert("theta1".to_string(), 0.5);
        new_params.insert("theta2".to_string(), 0.6);
        new_params.insert("theta3".to_string(), 0.7);

        circuit.set_parameters(&new_params).unwrap();

        let params = circuit.get_parameters();
        assert_eq!(params.get("theta1"), Some(&0.5));
        assert_eq!(params.get("theta2"), Some(&0.6));
        assert_eq!(params.get("theta3"), Some(&0.7));
    }

    #[test]
    fn test_circuit_parameter_shift_gradient() {
        let theta = 0.4;
        let mut circuit = VariationalCircuit::new(1);
        circuit.add_gate(VariationalGate::ry(QubitId(0), "phi".to_string(), theta));

        let loss_fn = |gates: &[VariationalGate]| -> f64 {
            let m = (gates[0].generator)(&gates[0].values);
            (m[[0, 0]] + m[[1, 1]]).re
        };

        let gradients = circuit.compute_gradients(loss_fn).unwrap();
        let expected = (2.0 * ((theta + PI / 2.0) / 2.0).cos()
            - 2.0 * ((theta - PI / 2.0) / 2.0).cos())
            / 2.0;
        assert!((gradients["phi"] - expected).abs() < 1e-10);
    }

    #[test]
    fn test_optimizer() {
        let mut circuit = VariationalCircuit::new(1);
        circuit.add_gate(VariationalGate::rx(QubitId(0), "theta".to_string(), 0.0));

        let mut optimizer = VariationalOptimizer::new(0.1, 0.9);

        let mut gradients = FxHashMap::default();
        gradients.insert("theta".to_string(), 1.0);

        // Step 1: v = 0.9·0 − 0.1·1 = −0.1, θ = 0 + v = −0.1
        optimizer.step(&mut circuit, &gradients).unwrap();
        let theta = *circuit.get_parameters().get("theta").unwrap();
        assert!((theta + 0.1).abs() < TOL, "after one step got {}", theta);

        // Step 2: v = 0.9·(−0.1) − 0.1·1 = −0.19, θ = −0.1 − 0.19 = −0.29
        optimizer.step(&mut circuit, &gradients).unwrap();
        let theta = *circuit.get_parameters().get("theta").unwrap();
        assert!((theta + 0.29).abs() < TOL, "after two steps got {}", theta);
    }
}